diff --git a/src/helpers/issue_helper.py b/src/helpers/issue_helper.py index 3ac5cae..7c24dac 100644 --- a/src/helpers/issue_helper.py +++ b/src/helpers/issue_helper.py @@ -10,7 +10,7 @@ def has_tasklist(issue_body: str) -> bool: """Return if the issue has a tasklist""" - return bool(re.search(r"- \[(.)] (.*)", issue_body)) + return bool(issue_body) and bool(re.search(r"- \[(.)] (.*)", issue_body)) def get_tasklist(issue_body: str) -> list[tuple[str, bool]]: diff --git a/src/managers/issue_manager.py b/src/managers/issue_manager.py index 1a59b94..449e8c9 100644 --- a/src/managers/issue_manager.py +++ b/src/managers/issue_manager.py @@ -56,10 +56,12 @@ def get_or_create_issue_job(event: IssuesEvent) -> IssueJob: @Config.call_if("issue_manager.enabled") def manage(event: IssuesEvent) -> Optional[IssueJob]: """Manage an issue or they task list.""" - if isinstance(event, (IssueOpenedEvent, IssueEditedEvent)): - return handle_task_list(event) - if isinstance(event, IssueClosedEvent): - close_sub_tasks(event) + issue = event.issue + if issue_helper.has_tasklist(issue.body): + if isinstance(event, (IssueOpenedEvent, IssueEditedEvent)): + return handle_task_list(event) + if isinstance(event, IssueClosedEvent): + close_sub_tasks(event) return None @@ -67,9 +69,7 @@ def manage(event: IssuesEvent) -> Optional[IssueJob]: def handle_task_list(event: IssuesEvent) -> Optional[IssueJob]: """Handle the task list of an issue.""" issue = event.issue - if not (tasklist := issue_helper.get_tasklist(issue.body or "")): - return None - issue_job = get_or_create_issue_job(event) + tasklist = issue_helper.get_tasklist(issue.body) existing_jobs = {} created_issues = {} for j in JobService.filter(original_issue_url=issue.url): @@ -77,6 +77,7 @@ def handle_task_list(event: IssuesEvent) -> Optional[IssueJob]: if j.issue_ref: created_issues[j.issue_ref] = j jobs = [] + for task, checked in tasklist: if task in existing_jobs: continue @@ -97,6 +98,7 @@ def handle_task_list(event: IssuesEvent) -> Optional[IssueJob]: if jobs: JobService.insert_many(jobs) + issue_job = get_or_create_issue_job(event) if issue_job.issue_job_status == IssueJobStatus.DONE: IssueJobService.update(issue_job, issue_job_status=IssueJobStatus.PENDING) return issue_job @@ -309,9 +311,11 @@ def close_issue_if_all_checked(issue_job: IssueJob) -> NoReturn: issue_job.installation_id, issue_job.issue_url, ) - tasklist = issue_helper.get_tasklist(issue.body) - if tasklist and all(checked for _, checked in tasklist): - issue.edit(state="closed") + issue_body = issue.body + if issue_helper.has_tasklist(issue_body): + tasklist = issue_helper.get_tasklist(issue_body) + if tasklist and all(checked for _, checked in tasklist): + issue.edit(state="closed") def process_update_progress(issue_job: IssueJob) -> NoReturn: diff --git a/tests/helpers/test_issue_helper.py b/tests/helpers/test_issue_helper.py index 0f947bd..605a051 100644 --- a/tests/helpers/test_issue_helper.py +++ b/tests/helpers/test_issue_helper.py @@ -13,9 +13,10 @@ def _task_list_data(): + # "issue_body, expected_tasks, assert_fail_message" return { "argvalues": [ - ("", [], "Empty string should not have a tasklist"), + (None, [], "Empty string should not have a tasklist"), ( "This is a sample issue description.", [], @@ -62,7 +63,7 @@ def test_has_tasklist(issue_body, expected, assert_fail_message): ) def test_get_tasklist(issue_body, expected_tasks, assert_fail_message): """Test get_tasklist with various inputs using parametrize""" - result = get_tasklist(issue_body) + result = get_tasklist(issue_body or "") assert result == expected_tasks, assert_fail_message diff --git a/tests/managers/test_issue_manager.py b/tests/managers/test_issue_manager.py index 5add011..57341a2 100644 --- a/tests/managers/test_issue_manager.py +++ b/tests/managers/test_issue_manager.py @@ -77,24 +77,25 @@ def test_get_or_create_issue_job( @pytest.mark.parametrize( - "event, handle_task_list_called, close_sub_tasks_called", + "event, handle_task_list_called, close_sub_tasks_called, has_task_list", [ - (IssueOpenedEvent, True, False), - (IssueEditedEvent, True, False), - (IssueClosedEvent, False, True), - (None, False, False), # Any other type of event + (IssueOpenedEvent, True, False, True), + (IssueOpenedEvent, False, False, False), + (IssueEditedEvent, True, False, True), + (IssueClosedEvent, False, True, True), + (IssueClosedEvent, False, False, False), + (None, False, False, True), # Any other type of event ], ) def test_manage( - event, - handle_task_list_called, - close_sub_tasks_called, + event, handle_task_list_called, close_sub_tasks_called, has_task_list, issue_helper ): + issue_helper.has_tasklist.return_value = has_task_list with ( patch("src.managers.issue_manager.handle_task_list") as handle_task_list_mock, patch("src.managers.issue_manager.close_sub_tasks") as close_sub_tasks_mock, ): - manage(Mock(spec=event)) + manage(Mock(spec=event, issue=Mock())) assert handle_task_list_mock.called == handle_task_list_called assert close_sub_tasks_mock.called == close_sub_tasks_called @@ -123,11 +124,6 @@ def test_manage( ], IssueJobStatus.DONE, ], - [ - [], - [], - None, - ], [ [("task1", False), ("task2", False)], [ @@ -141,7 +137,6 @@ def test_manage( "All new tasks", "Existing tasks and add new task", "Editing issue, with new task", - "No task list", "No new task in task list", ], ) @@ -151,6 +146,7 @@ def test_handle_task_list( for existing_task in existing_tasks: JobService.insert_one(Job(original_issue_url=event.issue.url, **existing_task)) + issue_helper.has_tasklist.return_value = bool(tasks) issue_helper.get_tasklist.return_value = tasks issue_job = Mock(issue_job_status=issue_job_status) with patch( @@ -458,15 +454,18 @@ def test_process_update_issue_body(issue_job): [[("task1", False), ("task2", False)], False], [[("task1", True), ("task2", True)], True], [[("task1", True), ("task2", False)], False], + [[], False], ], ids=[ "All opened", "All closed", "1 closed", + "No tasks", ], ) def test_close_issue_if_all_checked(tasks, issue_job, issue_helper, should_close): issue = Mock() + issue_helper.has_tasklist.return_value = bool(tasks) issue_helper.get_tasklist.return_value = tasks with (