diff --git a/reusable_workflows/repo_policies/bot_checks/check_bot_approved_files.py b/reusable_workflows/repo_policies/bot_checks/check_bot_approved_files.py new file mode 100644 index 0000000..4d72ce6 --- /dev/null +++ b/reusable_workflows/repo_policies/bot_checks/check_bot_approved_files.py @@ -0,0 +1,97 @@ +import json +import subprocess + +import github3 + +from check_membership.check_membership import is_approved_bot +from shared.utils import download_gh_file, load_env_vars + +BOT_APPROVED_FILES_PATH = ".github/repo_policies/bot_approved_files.json" +REQUIRED_ENV_VARS = [ + "USER", + "GH_TOKEN", + "GH_ORG", + "REPO", + "MERGE_BASE_SHA", + "BRANCH_HEAD_SHA", +] + + +def get_changed_files(merge_base_sha: str, branch_head_sha: str) -> list[str]: + """ + Compares the files changed in the current branch to the merge base. + """ + commit_range = f"{merge_base_sha}..{branch_head_sha}" + result = subprocess.run( + ["git", "diff", "--name-only", commit_range], stdout=subprocess.PIPE, text=True + ) + changed_files = result.stdout.strip().split("\n") + return changed_files + + +def get_approved_files_config(repo: github3.github.repo) -> str: + """ + Loads the config from the repository that contains the list of approved files. + """ + try: + config_file = download_gh_file(repo, BOT_APPROVED_FILES_PATH) + return config_file + except github3.exceptions.NotFoundError: + raise Exception( + f"No config file found. Make sure you have a file saved at {BOT_APPROVED_FILES_PATH}" + ) + + +def get_approved_files(config_file: str) -> list[str]: + """ + Extracts the list of approved files from the config file. + """ + try: + config = json.loads(config_file) + except json.JSONDecodeError: + raise Exception("Config file is not a valid JSON file") + try: + approved_files = config["approved_files"] + except KeyError: + raise Exception("No approved_files key found in config file") + + if len(approved_files) == 0: + raise Exception("No approved files found in config file") + return approved_files + + +def pr_is_blocked(env_vars: dict) -> bool: + """ + Logic to check if the Bot's PR can be merged or should be blocked. + """ + gh = github3.login(token=env_vars["GH_TOKEN"]) + repo = gh.repository(owner=env_vars["GH_ORG"], repository=env_vars["REPO"]) + changed_files = get_changed_files( + env_vars["MERGE_BASE_SHA"], env_vars["BRANCH_HEAD_SHA"] + ) + config = get_approved_files_config(repo) + approved_files = get_approved_files(config) + block_pr = not all(file in approved_files for file in changed_files) + return block_pr + + +def main() -> None: + env_vars = load_env_vars(REQUIRED_ENV_VARS) + user = env_vars["USER"] + + is_bot = is_approved_bot(user) + + if is_bot: + block_pr = pr_is_blocked(env_vars) + + else: + print( + f"{user} is not an approved bot. Letting CLA check handle contribution decision." + ) + block_pr = False + + subprocess.run(f"""echo 'block_pr={block_pr}' >> $GITHUB_OUTPUT""", shell=True) + + +if __name__ == "__main__": + main() diff --git a/reusable_workflows/shared/utils.py b/reusable_workflows/shared/utils.py index b550561..ff18378 100644 --- a/reusable_workflows/shared/utils.py +++ b/reusable_workflows/shared/utils.py @@ -1,10 +1,13 @@ +import os import time import github3 def download_gh_file(repo: github3.github.repo, file_path: str) -> str: - # sometimes the request does not work the first time, so set a retry + """ + Handles the download of a file from a GitHub repository. Retries 4 times if the request fails. + """ for attempt in range(5): try: file_content = repo.file_contents(file_path) @@ -22,3 +25,16 @@ def download_gh_file(repo: github3.github.repo, file_path: str) -> str: file_decoded = file_content.decoded.decode() return file_decoded + + +def load_env_vars(var_names: list[str]) -> dict: + """ + Loads required env vars and returns them as a dictionary. + """ + env_vars = {} + for var in var_names: + try: + env_vars[var] = os.environ[var] + except KeyError: + raise Exception(f"Environment variable '{var}' is not set.") + return env_vars diff --git a/reusable_workflows/tests/test_repo_policies.py b/reusable_workflows/tests/test_repo_policies.py new file mode 100644 index 0000000..6636517 --- /dev/null +++ b/reusable_workflows/tests/test_repo_policies.py @@ -0,0 +1,177 @@ +import subprocess +from unittest import mock + +import github3 +import pytest + +from repo_policies.bot_checks.check_bot_approved_files import ( + BOT_APPROVED_FILES_PATH, + get_approved_files, + get_approved_files_config, + get_changed_files, + main, + pr_is_blocked, +) + + +@mock.patch("subprocess.run") +def test_get_changed_files(mock_subprocess_run): + mock_subprocess_run.return_value = mock.Mock(stdout="file1.py\nfile2.py\n") + + changed_files = get_changed_files("merge_base_sha", "branch_head_sha") + + assert changed_files == ["file1.py", "file2.py"] + mock_subprocess_run.assert_called_once_with( + ["git", "diff", "--name-only", "merge_base_sha..branch_head_sha"], + stdout=subprocess.PIPE, + text=True, + ) + + +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.download_gh_file") +def test_get_approved_files_config(download_gh_file): + repo = mock.Mock() + config_file_mock = mock.Mock() + download_gh_file.return_value = config_file_mock + + config_file = get_approved_files_config(repo) + + download_gh_file.assert_called_once_with(repo, BOT_APPROVED_FILES_PATH) + assert config_file == config_file_mock + + +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.download_gh_file") +def test_get_approved_files_config_fails(download_gh_file): + repo = mock.Mock() + download_gh_file.side_effect = github3.exceptions.NotFoundError(mock.Mock()) + + with pytest.raises(Exception) as exc: + get_approved_files_config(repo) + + assert ( + # fmt: off + str(exc.value) == f"No config file found. Make sure you have a file saved at {BOT_APPROVED_FILES_PATH}" + ) + + +def test_get_approved_files(): + config_file = '{"approved_files": ["file1.py", "file2.py"]}' + approved_files = get_approved_files(config_file) + + assert approved_files == ["file1.py", "file2.py"] + + +def test_get_approved_files_not_json(): + config_file = "not a json file" + + with pytest.raises(Exception) as exc: + get_approved_files(config_file) + + assert str(exc.value) == "Config file is not a valid JSON file" + + +def test_get_approved_files_no_approved_files(): + config_file = '{"another_key": ["file1.py", "file2.py"]}' + + with pytest.raises(Exception) as exc: + get_approved_files(config_file) + + assert str(exc.value) == "No approved_files key found in config file" + + +def test_get_approved_files_no_files(): + config_file = '{"approved_files": []}' + + with pytest.raises(Exception) as exc: + get_approved_files(config_file) + + assert str(exc.value) == "No approved files found in config file" + + +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.get_changed_files") +@mock.patch( + "repo_policies.bot_checks.check_bot_approved_files.get_approved_files_config" +) +@mock.patch("github3.login") +def test_pr_is_blocked_false(gh_login, get_approved_files_config, get_changed_files): + env_vars = { + "GH_TOKEN": "token", + "GH_ORG": "org", + "REPO": "repo", + "MERGE_BASE_SHA": "base", + "BRANCH_HEAD_SHA": "head", + } + gh = mock.Mock() + gh_login.return_value = gh + repo = mock.Mock() + gh.repository.return_value = repo + get_changed_files.return_value = ["file1.py", "file2.py"] + get_approved_files_config.return_value = ( + '{"approved_files": ["file1.py", "file2.py", "file3.py"]}' + ) + + blocked = pr_is_blocked(env_vars) + + assert blocked is False + get_changed_files.assert_called_once_with("base", "head") + get_approved_files_config.assert_called_once_with(repo) + + +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.get_changed_files") +@mock.patch( + "repo_policies.bot_checks.check_bot_approved_files.get_approved_files_config" +) +@mock.patch("github3.login") +def test_pr_is_blocked_true(gh_login, get_approved_files_config, get_changed_files): + env_vars = { + "GH_TOKEN": "token", + "GH_ORG": "org", + "REPO": "repo", + "MERGE_BASE_SHA": "base", + "BRANCH_HEAD_SHA": "head", + } + gh = mock.Mock() + gh_login.return_value = gh + repo = mock.Mock() + gh.repository.return_value = repo + get_changed_files.return_value = ["file1.py", "file2.py"] + get_approved_files_config.return_value = '{"approved_files": ["file1.py"]}' + + blocked = pr_is_blocked(env_vars) + + assert blocked is True + get_changed_files.assert_called_once_with("base", "head") + get_approved_files_config.assert_called_once_with(repo) + + +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.load_env_vars") +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.is_approved_bot") +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.pr_is_blocked") +@mock.patch("subprocess.run") +def test_main_succeeds(subprocess_run, pr_is_blocked, is_approved_bot, load_env_vars): + env_vars = {"GH_TOKEN": "token", "USER": "user"} + load_env_vars.return_value = env_vars + is_approved_bot.return_value = True + pr_is_blocked.return_value = False + + main() + + subprocess_run.assert_called_once_with( + "echo 'block_pr=False' >> $GITHUB_OUTPUT", shell=True + ) + +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.load_env_vars") +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.is_approved_bot") +@mock.patch("repo_policies.bot_checks.check_bot_approved_files.pr_is_blocked") +@mock.patch("subprocess.run") +def test_main_not_a_bot(subprocess_run, pr_is_blocked, is_approved_bot, load_env_vars): + env_vars = {"GH_TOKEN": "token", "USER": "user"} + load_env_vars.return_value = env_vars + is_approved_bot.return_value = False + + main() + + subprocess_run.assert_called_once_with( + "echo 'block_pr=False' >> $GITHUB_OUTPUT", shell=True + ) + pr_is_blocked.assert_not_called() diff --git a/reusable_workflows/tests/test_utils.py b/reusable_workflows/tests/test_utils.py index d17a342..3ae2aaf 100644 --- a/reusable_workflows/tests/test_utils.py +++ b/reusable_workflows/tests/test_utils.py @@ -1,8 +1,9 @@ +import os from unittest import mock import pytest -from shared.utils import download_gh_file +from shared.utils import download_gh_file, load_env_vars def test_download_file_succeeds_first_try(): @@ -46,3 +47,18 @@ def test_download_file_fails(mock_get): assert repo.file_contents.call_count == 5 file_content_obj.decoded.assert_not_called + + +@mock.patch.dict(os.environ, {"REPO": "repo-1", "GH_TOKEN": "token"}) +def test_load_env_vars_succeeds(capfd): + env_vars = load_env_vars(["REPO", "GH_TOKEN"]) + + assert env_vars == {"REPO": "repo-1", "GH_TOKEN": "token"} + + +@mock.patch.dict(os.environ, {"REPO": "repo-1"}, clear=True) +def test_load_env_vars_fails(capfd): + with pytest.raises(Exception) as exc: + env_vars = load_env_vars(["REPO", "GH_TOKEN"]) + print(env_vars) + assert str(exc.value) == "Environment variable 'GH_TOKEN' is not set."