Refactor catch_dotenv hook and tests for improved readability and consistency

This commit is contained in:
Chris Rowe 2025-08-28 20:53:45 -06:00
parent 989ac68f29
commit c7f0dae9a4
No known key found for this signature in database
3 changed files with 259 additions and 98 deletions

View file

@ -6,16 +6,16 @@ import os
import re
import sys
import tempfile
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Iterable
# Defaults / constants
# Default file names, plus the banner line that is placed directly above the
# env entry when .gitignore is normalized.
DEFAULT_ENV_FILE = ".env"
DEFAULT_GITIGNORE_FILE = ".gitignore"
DEFAULT_EXAMPLE_ENV_FILE = ".env.example"
GITIGNORE_BANNER = "# Added by pre-commit hook to prevent committing secrets"
DEFAULT_ENV_FILE = '.env'
DEFAULT_GITIGNORE_FILE = '.gitignore'
DEFAULT_EXAMPLE_ENV_FILE = '.env.example'
GITIGNORE_BANNER = '# Added by pre-commit hook to prevent committing secrets'
# Matches the start of an assignment line: optional leading whitespace, an
# optional `export` prefix, then the variable name (captured as group 1)
# followed by optional whitespace and `=`.
_KEY_REGEX = re.compile(r"^\s*(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=")
_KEY_REGEX = re.compile(r'^\s*(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=')
def _atomic_write(path: str, data: str) -> None:
@ -28,16 +28,16 @@ def _atomic_write(path: str, data: str) -> None:
parallel (tests exercise concurrent normalization). Keeping this helper
local avoids adding any dependency.
"""
fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or '.')
try:
with os.fdopen(fd, "w", encoding="utf-8", newline="") as tmp_f:
with os.fdopen(fd, 'w', encoding='utf-8', newline='') as tmp_f:
tmp_f.write(data)
os.replace(tmp_path, path)
finally: # Clean up if replace failed
if os.path.exists(tmp_path): # (rare failure case)
try:
os.remove(tmp_path)
except OSError:
except OSError:
pass
@ -45,38 +45,51 @@ def _read_gitignore(gitignore_file: str) -> tuple[str, list[str]]:
"""Read and parse .gitignore file content."""
try:
if os.path.exists(gitignore_file):
with open(gitignore_file, "r", encoding="utf-8") as f:
with open(gitignore_file, encoding='utf-8') as f:
original_text = f.read()
lines = original_text.splitlines()
else:
original_text = ""
original_text = ''
lines = []
except OSError as exc:
print(f"ERROR: unable to read {gitignore_file}: {exc}", file=sys.stderr)
print(
f"ERROR: unable to read {gitignore_file}: {exc}",
file=sys.stderr,
)
raise
return original_text if lines else "", lines
return original_text if lines else '', lines
def _normalize_gitignore_lines(lines: list[str], env_file: str, banner: str) -> list[str]:
"""Normalize .gitignore lines by removing duplicates and adding canonical tail."""
def _normalize_gitignore_lines(
lines: list[str],
env_file: str,
banner: str,
) -> list[str]:
"""Normalize .gitignore lines by removing duplicates and canonical tail."""
# Trim trailing blank lines
while lines and not lines[-1].strip():
lines.pop()
# Remove existing occurrences
filtered: list[str] = [ln for ln in lines if ln.strip() not in {env_file, banner}]
filtered: list[str] = [
ln for ln in lines if ln.strip() not in {env_file, banner}
]
if filtered and filtered[-1].strip():
filtered.append("") # ensure single blank before banner
filtered.append('') # ensure single blank before banner
elif not filtered: # empty file -> still separate section visually
filtered.append("")
filtered.append('')
filtered.append(banner)
filtered.append(env_file)
return filtered
def ensure_env_in_gitignore(env_file: str, gitignore_file: str, banner: str) -> bool:
def ensure_env_in_gitignore(
env_file: str,
gitignore_file: str,
banner: str,
) -> bool:
"""Ensure canonical banner + env tail in .gitignore.
Returns True only when the file content was changed. Returns False both
@ -89,27 +102,30 @@ def ensure_env_in_gitignore(env_file: str, gitignore_file: str, banner: str) ->
return False
filtered = _normalize_gitignore_lines(lines, env_file, banner)
new_content = "\n".join(filtered) + "\n"
new_content = '\n'.join(filtered) + '\n'
# Normalize original content to a single trailing newline for comparison
normalized_original = original_content_str
if normalized_original and not normalized_original.endswith("\n"):
normalized_original += "\n"
if normalized_original and not normalized_original.endswith('\n'):
normalized_original += '\n'
if new_content == normalized_original:
return False
try:
_atomic_write(gitignore_file, new_content)
return True
except OSError as exc:
print(f"ERROR: unable to write {gitignore_file}: {exc}", file=sys.stderr)
except OSError as exc:
print(
f"ERROR: unable to write {gitignore_file}: {exc}",
file=sys.stderr,
)
return False
def create_example_env(src_env: str, example_file: str) -> bool:
"""Generate .env.example with unique KEY= lines (no values)."""
try:
with open(src_env, "r", encoding="utf-8") as f_env:
with open(src_env, encoding='utf-8') as f_env:
lines = f_env.readlines()
except OSError as exc:
print(f"ERROR: unable to read {src_env}: {exc}", file=sys.stderr)
@ -136,33 +152,51 @@ def create_example_env(src_env: str, example_file: str) -> bool:
]
body = [f"{k}=" for k in keys]
try:
_atomic_write(example_file, "\n".join(header + body) + "\n")
_atomic_write(example_file, '\n'.join(header + body) + '\n')
return True
except OSError as exc: # pragma: no cover
print(f"ERROR: unable to write '{example_file}': {exc}", file=sys.stderr)
print(
f"ERROR: unable to write '{example_file}': {exc}",
file=sys.stderr,
)
return False
def _has_env(filenames: Iterable[str], env_file: str) -> bool:
"""Return True if any staged path refers to a target env file by basename."""
"""Return True if any staged path refers to target env file by basename."""
return any(os.path.basename(name) == env_file for name in filenames)
def _print_failure(env_file: str, gitignore_file: str, example_created: bool, gitignore_modified: bool) -> None:
def _print_failure(
env_file: str,
gitignore_file: str,
example_created: bool,
gitignore_modified: bool,
) -> None:
# Match typical hook output style: one short line per action.
print(f"Blocked committing {env_file}.")
if gitignore_modified:
print(f"Updated {gitignore_file}.")
if example_created:
print("Generated .env.example.")
print('Generated .env.example.')
print(f"Remove {env_file} from the commit and retry.")
def main(argv: Sequence[str] | None = None) -> int:
"""Hook entry-point."""
parser = argparse.ArgumentParser(description="Blocks committing .env files.")
parser.add_argument('filenames', nargs='*', help='Staged filenames (supplied by pre-commit).')
parser.add_argument('--create-example', action='store_true', help='Generate example env file (.env.example).')
parser = argparse.ArgumentParser(
description='Blocks committing .env files.',
)
parser.add_argument(
'filenames',
nargs='*',
help='Staged filenames (supplied by pre-commit).',
)
parser.add_argument(
'--create-example',
action='store_true',
help='Generate example env file (.env.example).',
)
args = parser.parse_args(argv)
env_file = DEFAULT_ENV_FILE
# Use current working directory as repository root (pre-commit executes
@ -175,14 +209,26 @@ def main(argv: Sequence[str] | None = None) -> int:
if not _has_env(args.filenames, env_file):
return 0
gitignore_modified = ensure_env_in_gitignore(env_file, gitignore_file, GITIGNORE_BANNER)
gitignore_modified = ensure_env_in_gitignore(
env_file,
gitignore_file,
GITIGNORE_BANNER,
)
example_created = False
if args.create_example:
# Source env is always looked up relative to repo root
if os.path.exists(env_abspath):
example_created = create_example_env(env_abspath, example_file)
example_created = create_example_env(
env_abspath,
example_file,
)
_print_failure(env_file, gitignore_file, example_created, gitignore_modified)
_print_failure(
env_file,
gitignore_file,
example_created,
gitignore_modified,
)
return 1 # Block commit