Skip to content

Commit

Permalink
feat(IDX): create bot check code (#70)
Browse files Browse the repository at this point in the history
* wip

* wip

* add file

* feat(IDX): create bot check code

* lint

* remove workflow code

* add method descriptions

* add test

* update descriptions
  • Loading branch information
cgundy authored Dec 4, 2024
1 parent 9f79bf6 commit 827d94c
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import json
import subprocess

import github3

from check_membership.check_membership import is_approved_bot
from shared.utils import download_gh_file, load_env_vars

# Path (within the repository) of the JSON config that lists the files an
# approved bot is allowed to change; read through download_gh_file.
BOT_APPROVED_FILES_PATH = ".github/repo_policies/bot_approved_files.json"
# Environment variables that main() loads via load_env_vars before running the
# check; load_env_vars raises if any of them is missing.
REQUIRED_ENV_VARS = [
    "USER",
    "GH_TOKEN",
    "GH_ORG",
    "REPO",
    "MERGE_BASE_SHA",
    "BRANCH_HEAD_SHA",
]


def get_changed_files(merge_base_sha: str, branch_head_sha: str) -> list[str]:
    """
    Return the paths that differ between the merge base and the branch head.

    Runs ``git diff --name-only <merge_base_sha>..<branch_head_sha>`` in the
    current working directory and returns one entry per non-empty output line.
    An empty diff yields ``[]``.

    NOTE(review): the exit status of git is not checked, so a failing git
    command is indistinguishable from an empty diff here.
    """
    commit_range = f"{merge_base_sha}..{branch_head_sha}"
    result = subprocess.run(
        ["git", "diff", "--name-only", commit_range], stdout=subprocess.PIPE, text=True
    )
    # Splitting empty output on "\n" gives [""]; dropping empty entries avoids
    # reporting a phantom changed file named "" that can never be approved.
    return [path for path in result.stdout.split("\n") if path]


def get_approved_files_config(repo: github3.github.repo) -> str:
    """
    Load the bot-approved-files config from the given repository.

    Downloads the file stored at BOT_APPROVED_FILES_PATH through
    download_gh_file and returns that helper's result unchanged.

    Raises:
        Exception: if the download raises github3.exceptions.NotFoundError.
            The original error is attached as the cause so the traceback
            still shows the underlying GitHub failure.
    """
    try:
        return download_gh_file(repo, BOT_APPROVED_FILES_PATH)
    except github3.exceptions.NotFoundError as err:
        raise Exception(
            f"No config file found. Make sure you have a file saved at {BOT_APPROVED_FILES_PATH}"
        ) from err


def get_approved_files(config_file: str) -> list[str]:
    """
    Extract the list of approved files from the raw JSON config text.

    Args:
        config_file: the config file contents as a JSON string.

    Returns:
        The non-empty list stored under the ``approved_files`` key.

    Raises:
        Exception: if the text is not valid JSON, if the top-level value is not
            an object containing ``approved_files``, if that value is not a
            list, or if the list is empty.
    """
    try:
        config = json.loads(config_file)
    except json.JSONDecodeError as err:
        raise Exception("Config file is not a valid JSON file") from err

    # Valid JSON is not necessarily an object (e.g. `[]` or `"text"`); indexing
    # such a value with a string key would surface as a raw TypeError.
    if not isinstance(config, dict) or "approved_files" not in config:
        raise Exception("No approved_files key found in config file")

    approved_files = config["approved_files"]
    # A string here would turn later `file in approved_files` checks into
    # substring matches, and null would crash on the emptiness check.
    if not isinstance(approved_files, list):
        raise Exception("approved_files must be a list of file paths")
    if not approved_files:
        raise Exception("No approved files found in config file")
    return approved_files


def pr_is_blocked(env_vars: dict) -> bool:
    """
    Decide whether the bot's PR has to be blocked.

    Logs in to GitHub with ``GH_TOKEN``, resolves the ``GH_ORG``/``REPO``
    repository, and compares the files changed between ``MERGE_BASE_SHA`` and
    ``BRANCH_HEAD_SHA`` against the approved list from the repository config.

    Returns:
        True when at least one changed file is not in the approved list,
        False when every changed file is approved.
    """
    client = github3.login(token=env_vars["GH_TOKEN"])
    repository = client.repository(
        owner=env_vars["GH_ORG"], repository=env_vars["REPO"]
    )
    changed = get_changed_files(
        env_vars["MERGE_BASE_SHA"], env_vars["BRANCH_HEAD_SHA"]
    )
    approved = get_approved_files(get_approved_files_config(repository))
    # Blocked as soon as a single changed path falls outside the approved list.
    return any(path not in approved for path in changed)


def main() -> None:
    """
    Entry point: compute ``block_pr`` and append it to ``$GITHUB_OUTPUT``.

    For an approved bot the decision comes from pr_is_blocked; for any other
    user the PR is not blocked by this check.
    """
    env_vars = load_env_vars(REQUIRED_ENV_VARS)
    user = env_vars["USER"]

    if not is_approved_bot(user):
        print(
            f"{user} is not an approved bot. Letting CLA check handle contribution decision."
        )
        block_pr = False
    else:
        block_pr = pr_is_blocked(env_vars)

    output_command = f"echo 'block_pr={block_pr}' >> $GITHUB_OUTPUT"
    subprocess.run(output_command, shell=True)


# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
18 changes: 17 additions & 1 deletion reusable_workflows/shared/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
import time

import github3


def download_gh_file(repo: github3.github.repo, file_path: str) -> str:
# sometimes the request does not work the first time, so set a retry
"""
Handles the download of a file from a GitHub repository. Retries 4 times if the request fails.
"""
for attempt in range(5):
try:
file_content = repo.file_contents(file_path)
Expand All @@ -22,3 +25,16 @@ def download_gh_file(repo: github3.github.repo, file_path: str) -> str:

file_decoded = file_content.decoded.decode()
return file_decoded


def load_env_vars(var_names: list[str]) -> dict:
    """
    Read each named environment variable and return them as a name -> value dict.

    Raises:
        Exception: for the first name in ``var_names`` that is not set.
    """
    loaded = {}
    for name in var_names:
        try:
            value = os.environ[name]
        except KeyError:
            raise Exception(f"Environment variable '{name}' is not set.")
        loaded[name] = value
    return loaded
177 changes: 177 additions & 0 deletions reusable_workflows/tests/test_repo_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import subprocess
from unittest import mock

import github3
import pytest

from repo_policies.bot_checks.check_bot_approved_files import (
BOT_APPROVED_FILES_PATH,
get_approved_files,
get_approved_files_config,
get_changed_files,
main,
pr_is_blocked,
)


@mock.patch("subprocess.run")
def test_get_changed_files(mock_subprocess_run):
    """git's newline-separated output is turned into a list of file names."""
    mock_subprocess_run.return_value = mock.Mock(stdout="file1.py\nfile2.py\n")
    expected_command = ["git", "diff", "--name-only", "merge_base_sha..branch_head_sha"]

    result = get_changed_files("merge_base_sha", "branch_head_sha")

    mock_subprocess_run.assert_called_once_with(
        expected_command, stdout=subprocess.PIPE, text=True
    )
    assert result == ["file1.py", "file2.py"]


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.download_gh_file")
def test_get_approved_files_config(download_gh_file):
    """The config loader returns whatever download_gh_file yields for the policy path."""
    repo = mock.Mock()
    expected = download_gh_file.return_value = mock.Mock()

    assert get_approved_files_config(repo) == expected
    download_gh_file.assert_called_once_with(repo, BOT_APPROVED_FILES_PATH)


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.download_gh_file")
def test_get_approved_files_config_fails(download_gh_file):
    """A missing config file is reported with a message naming the expected path."""
    download_gh_file.side_effect = github3.exceptions.NotFoundError(mock.Mock())
    expected_message = (
        f"No config file found. Make sure you have a file saved at {BOT_APPROVED_FILES_PATH}"
    )

    with pytest.raises(Exception) as excinfo:
        get_approved_files_config(mock.Mock())

    assert str(excinfo.value) == expected_message


def test_get_approved_files():
    """A well-formed config yields the list stored under approved_files."""
    parsed = get_approved_files('{"approved_files": ["file1.py", "file2.py"]}')

    assert parsed == ["file1.py", "file2.py"]


def test_get_approved_files_not_json():
    """Text that is not JSON is rejected with a dedicated message."""
    with pytest.raises(Exception) as excinfo:
        get_approved_files("not a json file")

    message = str(excinfo.value)
    assert message == "Config file is not a valid JSON file"


def test_get_approved_files_no_approved_files():
    """A config without the approved_files key is rejected."""
    with pytest.raises(Exception) as excinfo:
        get_approved_files('{"another_key": ["file1.py", "file2.py"]}')

    message = str(excinfo.value)
    assert message == "No approved_files key found in config file"


def test_get_approved_files_no_files():
    """An empty approved_files list is rejected."""
    with pytest.raises(Exception) as excinfo:
        get_approved_files('{"approved_files": []}')

    message = str(excinfo.value)
    assert message == "No approved files found in config file"


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.get_changed_files")
@mock.patch(
    "repo_policies.bot_checks.check_bot_approved_files.get_approved_files_config"
)
@mock.patch("github3.login")
def test_pr_is_blocked_false(gh_login, get_approved_files_config, get_changed_files):
    """The PR is not blocked when every changed file is in the approved list."""
    repo = mock.Mock()
    gh_login.return_value.repository.return_value = repo
    get_changed_files.return_value = ["file1.py", "file2.py"]
    get_approved_files_config.return_value = (
        '{"approved_files": ["file1.py", "file2.py", "file3.py"]}'
    )
    env_vars = dict(
        GH_TOKEN="token",
        GH_ORG="org",
        REPO="repo",
        MERGE_BASE_SHA="base",
        BRANCH_HEAD_SHA="head",
    )

    assert pr_is_blocked(env_vars) is False

    get_changed_files.assert_called_once_with("base", "head")
    get_approved_files_config.assert_called_once_with(repo)


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.get_changed_files")
@mock.patch(
    "repo_policies.bot_checks.check_bot_approved_files.get_approved_files_config"
)
@mock.patch("github3.login")
def test_pr_is_blocked_true(gh_login, get_approved_files_config, get_changed_files):
    """The PR is blocked when a changed file is missing from the approved list."""
    repo = mock.Mock()
    gh_login.return_value.repository.return_value = repo
    get_changed_files.return_value = ["file1.py", "file2.py"]
    get_approved_files_config.return_value = '{"approved_files": ["file1.py"]}'
    env_vars = dict(
        GH_TOKEN="token",
        GH_ORG="org",
        REPO="repo",
        MERGE_BASE_SHA="base",
        BRANCH_HEAD_SHA="head",
    )

    assert pr_is_blocked(env_vars) is True

    get_changed_files.assert_called_once_with("base", "head")
    get_approved_files_config.assert_called_once_with(repo)


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.load_env_vars")
@mock.patch("repo_policies.bot_checks.check_bot_approved_files.is_approved_bot")
@mock.patch("repo_policies.bot_checks.check_bot_approved_files.pr_is_blocked")
@mock.patch("subprocess.run")
def test_main_succeeds(subprocess_run, pr_is_blocked, is_approved_bot, load_env_vars):
    """For an approved bot, main() defers to pr_is_blocked and writes its verdict."""
    env_vars = {"GH_TOKEN": "token", "USER": "user"}
    load_env_vars.return_value = env_vars
    is_approved_bot.return_value = True
    pr_is_blocked.return_value = False

    main()

    # The non-bot path also emits block_pr=False, so the output alone cannot
    # prove the bot path ran; check the collaborators were actually consulted.
    is_approved_bot.assert_called_once_with("user")
    pr_is_blocked.assert_called_once_with(env_vars)
    subprocess_run.assert_called_once_with(
        "echo 'block_pr=False' >> $GITHUB_OUTPUT", shell=True
    )

@mock.patch("repo_policies.bot_checks.check_bot_approved_files.load_env_vars")
@mock.patch("repo_policies.bot_checks.check_bot_approved_files.is_approved_bot")
@mock.patch("repo_policies.bot_checks.check_bot_approved_files.pr_is_blocked")
@mock.patch("subprocess.run")
def test_main_not_a_bot(subprocess_run, pr_is_blocked, is_approved_bot, load_env_vars):
    """For a non-bot user, main() skips the file check and never blocks the PR."""
    load_env_vars.return_value = {"GH_TOKEN": "token", "USER": "user"}
    is_approved_bot.return_value = False
    expected_command = "echo 'block_pr=False' >> $GITHUB_OUTPUT"

    main()

    pr_is_blocked.assert_not_called()
    subprocess_run.assert_called_once_with(expected_command, shell=True)
18 changes: 17 additions & 1 deletion reusable_workflows/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from unittest import mock

import pytest

from shared.utils import download_gh_file
from shared.utils import download_gh_file, load_env_vars


def test_download_file_succeeds_first_try():
Expand Down Expand Up @@ -46,3 +47,18 @@ def test_download_file_fails(mock_get):

assert repo.file_contents.call_count == 5
file_content_obj.decoded.assert_not_called


@mock.patch.dict(os.environ, {"REPO": "repo-1", "GH_TOKEN": "token"})
def test_load_env_vars_succeeds():
    """All requested variables are returned when each one is set."""
    # The capfd fixture was requested but never used, so it is no longer requested.
    env_vars = load_env_vars(["REPO", "GH_TOKEN"])

    assert env_vars == {"REPO": "repo-1", "GH_TOKEN": "token"}


@mock.patch.dict(os.environ, {"REPO": "repo-1"}, clear=True)
def test_load_env_vars_fails():
    """A missing variable raises an error that names the variable."""
    with pytest.raises(Exception) as exc:
        # Only the raising call belongs in this block: statements after it
        # would never execute.
        load_env_vars(["REPO", "GH_TOKEN"])

    # Checked after the block so the assertion actually runs.
    assert str(exc.value) == "Environment variable 'GH_TOKEN' is not set."

0 comments on commit 827d94c

Please sign in to comment.