diff --git a/pre_commit_hooks/detect_aws_credentials.py b/pre_commit_hooks/detect_aws_credentials.py index 4f59d9c..07988f0 100644 --- a/pre_commit_hooks/detect_aws_credentials.py +++ b/pre_commit_hooks/detect_aws_credentials.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse import configparser import os +import re from typing import NamedTuple from typing import Sequence @@ -89,6 +90,11 @@ def check_file_for_aws_keys( return bad_files +def filter_keys(keys: set[str], exclude: str) -> set[str]: + pattern = re.compile(exclude) + return {key for key in keys if not pattern.match(key)} + + def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='+', help='Filenames to run') @@ -110,6 +116,12 @@ def main(argv: Sequence[str] | None = None) -> int: action='store_true', help='Allow hook to pass when no credentials are detected.', ) + parser.add_argument( + '--exclude-values', + dest='exclude_values', + default='^$', + help='Regular expression for secret values that should be excluded.', + ) args = parser.parse_args(argv) credential_files = set(args.credentials_file) @@ -137,7 +149,7 @@ def main(argv: Sequence[str] | None = None) -> int: ) return 2 - keys_b = {key.encode() for key in keys} + keys_b = {key.encode() for key in filter_keys(keys, args.exclude_values)} bad_filenames = check_file_for_aws_keys(args.filenames, keys_b) if bad_filenames: for bad_file in bad_filenames: diff --git a/tests/detect_aws_credentials_test.py b/tests/detect_aws_credentials_test.py index afda47a..da39e06 100644 --- a/tests/detect_aws_credentials_test.py +++ b/tests/detect_aws_credentials_test.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest +from pre_commit_hooks.detect_aws_credentials import filter_keys from pre_commit_hooks.detect_aws_credentials import get_aws_cred_files_from_env from pre_commit_hooks.detect_aws_credentials import get_aws_secrets_from_env from pre_commit_hooks.detect_aws_credentials import get_aws_secrets_from_file @@ -101,6 +102,18 @@ def test_get_aws_secrets_from_file(filename, expected_keys): assert keys == expected_keys +@pytest.mark.parametrize( + ('keys', 'exclude', 'expected_ret'), + ( + ({'foo', 'bar'}, '^$', {'foo', 'bar'}), + ({'foo', 'bar', 'baz'}, '^bar$', {'foo', 'baz'}), + ({'foo', 'bar', 'baz'}, '^ba', {'foo'}), + ), +) +def test_filter_keys(keys, exclude, expected_ret): + assert filter_keys(keys, exclude) == expected_ret + + @pytest.mark.parametrize( ('filename', 'expected_retval'), ( @@ -123,6 +136,26 @@ def test_detect_aws_credentials(filename, expected_retval): assert ret == expected_retval +@pytest.mark.parametrize( + ('filename', 'exclude', 'expected_retval'), + ( + ('aws_config_with_secret.ini', '.*rpg.*', 0), + ('aws_config_with_multiple_sections.ini', '.*rpg.*', 1), + ), +) +def test_detect_aws_credentials_with_exclude( + filename, exclude, expected_retval, +): + ret = main(( + get_resource_path(filename), + '--exclude-values', + exclude, + '--credentials-file', + 'testing/resources/aws_config_with_multiple_sections.ini', + )) + assert ret == expected_retval + + def test_allows_arbitrarily_encoded_files(tmpdir): src_ini = tmpdir.join('src.ini') src_ini.write(