diff --git a/src/codemodder/codemods/regex_transformer.py b/src/codemodder/codemods/regex_transformer.py index 061a302a..efc1c477 100644 --- a/src/codemodder/codemods/regex_transformer.py +++ b/src/codemodder/codemods/regex_transformer.py @@ -23,26 +23,17 @@ def __init__( self.replacement = replacement self.change_description = change_description - def apply( - self, - context: CodemodExecutionContext, - file_context: FileContext, - results: list[Result] | None, - ) -> ChangeSet | None: + def _apply_regex(self, line): + return re.sub(self.pattern, self.replacement, line) + + def _apply(self, original_lines, file_context, results): del results changes = [] updated_lines = [] - original_lines = ( - file_context.file_path.read_bytes() - .decode("utf-8") - .splitlines(keepends=True) - ) - for lineno, line in enumerate(original_lines): - # TODO: use results to filter out which lines to change - changed_line = re.sub(self.pattern, self.replacement, line) + changed_line = self._apply_regex(line) updated_lines.append(changed_line) if line != changed_line: changes.append( @@ -52,6 +43,22 @@ def apply( findings=file_context.get_findings_for_location(lineno), ) ) + return changes, updated_lines + + def apply( + self, + context: CodemodExecutionContext, + file_context: FileContext, + results: list[Result] | None, + ) -> ChangeSet | None: + + original_lines = ( + file_context.file_path.read_bytes() + .decode("utf-8") + .splitlines(keepends=True) + ) + + changes, updated_lines = self._apply(original_lines, file_context, results) if not changes: logger.debug("No changes produced for %s", file_context.file_path) @@ -67,3 +74,46 @@ def apply( diff=diff, changes=changes, ) + + +class SastRegexTransformerPipeline(RegexTransformerPipeline): + def line_matches_result(self, lineno: int, result_linenums: list[int]) -> bool: + return lineno in result_linenums + + def report_unfixed(self, file_context: FileContext, line_number: int, reason: str): + findings = file_context.get_findings_for_location(line_number) + file_context.add_unfixed_findings(findings, reason, line_number) + + def _apply(self, original_lines, file_context, results): + changes = [] + updated_lines = [] + if results is not None and not results: + return changes, updated_lines + + result_linenums = [ + location.start.line for result in results for location in result.locations + ] + for lineno, line in enumerate(original_lines): + if self.line_matches_result(one_idx_lineno := lineno + 1, result_linenums): + changed_line = self._apply_regex(line) + updated_lines.append(changed_line) + if line == changed_line: + logger.warn("Unable to update html line: %s", line) + self.report_unfixed( + file_context, + one_idx_lineno, + reason="Unable to update html line", + ) + continue + + changes.append( + Change( + lineNumber=lineno + 1, + description=self.change_description, + findings=file_context.get_findings_for_location(lineno), + ) + ) + + else: + updated_lines.append(line) + return changes, updated_lines diff --git a/tests/test_regex_transformer.py b/tests/test_regex_transformer.py index 7c7b6337..b9988ae0 100644 --- a/tests/test_regex_transformer.py +++ b/tests/test_regex_transformer.py @@ -1,8 +1,12 @@ import logging -from codemodder.codemods.regex_transformer import RegexTransformerPipeline +from codemodder.codemods.regex_transformer import ( + RegexTransformerPipeline, + SastRegexTransformerPipeline, +) from codemodder.context import CodemodExecutionContext from codemodder.file_context import FileContext +from codemodder.semgrep import SemgrepResult def test_transformer_no_change(mocker, caplog, tmp_path_factory): @@ -106,3 +110,71 @@ def test_transformer_windows_carriage(mocker, tmp_path_factory): assert changeset is not None assert code.read_bytes() == text.replace(b"world", b"Earth") assert changeset.changes[0].lineNumber == 1 + + +def test_sast_transformer(mocker, tmp_path_factory): + base_dir = tmp_path_factory.mktemp("foo") + code = base_dir / "code.py" + text = "# Something that will match pattern hello" + code.write_text(text) + + file_context = FileContext( + base_dir, + code, + ) + execution_context = CodemodExecutionContext( + directory=base_dir, + dry_run=False, + verbose=False, + registry=mocker.MagicMock(), + providers=mocker.MagicMock(), + repo_manager=mocker.MagicMock(), + path_include=[], + path_exclude=[], + ) + pipeline = SastRegexTransformerPipeline( + pattern=r"hello", replacement="bye", change_description="testing" + ) + + data = { + "runs": [ + { + "results": [ + { + "fingerprints": {"matchBasedId/v1": "123"}, + "locations": [ + { + "ruleId": "rule", + "physicalLocation": { + "artifactLocation": { + "uri": "code.py", + "uriBaseId": "%SRCROOT%", + }, + "region": { + "snippet": {"text": "snip"}, + "endColumn": 1, + "endLine": 1, + "startColumn": 1, + "startLine": 1, + }, + }, + } + ], + "ruleId": "rule", + } + ] + } + ] + } + sarif_run = data["runs"] + sarif_results = sarif_run[0]["results"] + results = [SemgrepResult.from_sarif(sarif_results[0], sarif_run)] + + changeset = pipeline.apply( + context=execution_context, + file_context=file_context, + results=results, + ) + assert changeset is not None + assert code.read_text() == text.replace("hello", "bye") + assert changeset.changes[0].lineNumber == 1