diff --git a/README.md b/README.md index 8432455..05a2a5b 100644 --- a/README.md +++ b/README.md @@ -152,7 +152,9 @@ verifies that test files are named correctly. #### `no-commit-to-branch` Protect specific branches from direct checkins. - Use `args: [--branch, staging, --branch, main]` to set the branch. - Both `main` and `master` are protected by default if no branch argument is set. + If no branch argument is set, the hook auto-detects the repository's default + branch from `origin/HEAD`. Falls back to protecting both `main` and `master` + if `origin/HEAD` is not configured. - `-b` / `--branch` may be specified multiple times to protect multiple branches. - `-p` / `--pattern` can be used to protect branches that match a supplied regex diff --git a/pre_commit_hooks/no_commit_to_branch.py b/pre_commit_hooks/no_commit_to_branch.py index b0b8b23..c5edf21 100644 --- a/pre_commit_hooks/no_commit_to_branch.py +++ b/pre_commit_hooks/no_commit_to_branch.py @@ -9,6 +9,17 @@ from pre_commit_hooks.util import CalledProcessError from pre_commit_hooks.util import cmd_output +def _default_branch() -> frozenset[str]: + try: + ref = cmd_output('git', 'rev-parse', '--abbrev-ref', 'origin/HEAD') + branch = ref.strip().removeprefix('origin/') + if branch: + return frozenset((branch,)) + except CalledProcessError: + pass + return frozenset(('master', 'main')) + + def is_on_branch( protected: AbstractSet[str], patterns: AbstractSet[str] = frozenset(), @@ -39,7 +50,7 @@ def main(argv: Sequence[str] | None = None) -> int: ) args = parser.parse_args(argv) - protected = frozenset(args.branch or ('master', 'main')) + protected = frozenset(args.branch) if args.branch else _default_branch() patterns = frozenset(args.pattern or ()) return int(is_on_branch(protected, patterns)) diff --git a/tests/no_commit_to_branch_test.py b/tests/no_commit_to_branch_test.py index 7d37e49..ccd83e4 100644 --- a/tests/no_commit_to_branch_test.py +++ b/tests/no_commit_to_branch_test.py @@ -1,7 +1,10 @@ from __future__ import annotations +from unittest.mock import patch + import pytest +from pre_commit_hooks.no_commit_to_branch import _default_branch from pre_commit_hooks.no_commit_to_branch import is_on_branch from pre_commit_hooks.no_commit_to_branch import main from pre_commit_hooks.util import cmd_output @@ -78,3 +81,43 @@ def test_default_branch_names(temp_git_dir, branch_name): with temp_git_dir.as_cwd(): cmd_output('git', 'checkout', '-b', branch_name) assert main(()) == 1 + + +def test_default_branch_falls_back_when_empty_branch(tmpdir): + with patch( + 'pre_commit_hooks.no_commit_to_branch.cmd_output', + return_value='origin/', + ): + assert _default_branch() == frozenset({'master', 'main'}) + + +def test_default_branch_detects_from_origin(tmpdir): + remote = tmpdir.join('remote') + cmd_output('git', 'init', '--', str(remote)) + with remote.as_cwd(): + cmd_output('git', 'checkout', '-b', 'develop') + git_commit('--allow-empty', '-m', 'init') + + local = tmpdir.join('local') + cmd_output('git', 'clone', str(remote), str(local)) + + with local.as_cwd(): + assert _default_branch() == frozenset({'develop'}) + + +def test_main_blocks_detected_default_branch(tmpdir): + remote = tmpdir.join('remote') + cmd_output('git', 'init', '--', str(remote)) + with remote.as_cwd(): + cmd_output('git', 'checkout', '-b', 'develop') + git_commit('--allow-empty', '-m', 'init') + + local = tmpdir.join('local') + cmd_output('git', 'clone', str(remote), str(local)) + + with local.as_cwd(): + # On detected default branch — should be blocked + assert main(()) == 1 + # On a feature branch — should pass + cmd_output('git', 'checkout', '-b', 'feature') + assert main(()) == 0