Adds commit classification rule #397

Merged (4 commits, Jul 17, 2024)
54 changes: 53 additions & 1 deletion prospector/llm/llm_service.py
@@ -3,9 +3,11 @@
import validators
from langchain_core.language_models.llms import LLM
from langchain_core.output_parsers import StrOutputParser
from requests import HTTPError

from llm.instantiation import create_model_instance
from llm.prompts import prompt_best_guess
from llm.prompts.classify_commit import zero_shot as cc_zero_shot
from llm.prompts.get_repository_url import prompt_best_guess
from log.logger import logger
from util.config_parser import LLMServiceConfig
from util.singleton import Singleton
@@ -74,3 +76,53 @@ def get_repository_url(self, advisory_description, advisory_references) -> str:
            raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")

        return url

    def classify_commit(
        self, diff: str, repository_name: str, commit_message: str
    ) -> bool:
        """Ask an LLM whether a commit is security relevant or not.

        Args:
            diff (str): The diff of the commit to classify.
            repository_name (str): The name of the repository the commit belongs to.
            commit_message (str): The commit's message.

        Returns:
            True if the commit is deemed security relevant, False if not (a 400
            response from the model API, e.g. for an oversized diff, also yields False).

        Raises:
            RuntimeError if the model invocation fails or the response is not valid.
        """
        try:
            chain = cc_zero_shot | self.model | StrOutputParser()

            is_relevant = chain.invoke(
                {
                    "diff": diff,
                    "repository_name": repository_name,
                    "commit_message": commit_message,
                }
            )
            logger.info(f"LLM returned is_relevant={is_relevant}")

        except HTTPError as e:
            # If the diff is too big, the API returns a 400 error; silently
            # ignore it by returning False for this commit.
            status_code = e.response.status_code
            if status_code == 400:
                return False
            raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")
        except Exception as e:
            raise RuntimeError(f"Prompt-model chain could not be invoked: {e}")

        if is_relevant in [
            "True",
            "ANSWER:True",
            "```ANSWER:True```",
        ]:
            return True
        elif is_relevant in [
            "False",
            "ANSWER:False",
            "```ANSWER:False```",
        ]:
            return False
        else:
            raise RuntimeError(f"The model returned an invalid response: {is_relevant}")
16 changes: 16 additions & 0 deletions prospector/llm/prompts/classify_commit.py
@@ -0,0 +1,16 @@
from langchain.prompts import PromptTemplate

zero_shot = PromptTemplate.from_template(
    """Is the following commit security relevant or not?
Please provide the output as a boolean value, either True or False.
If it is security relevant, just answer True; otherwise answer False. Do not return anything else.

To provide you with some context, the name of the repository is: {repository_name}, and the
commit message is: {commit_message}.

Finally, here is the diff of the commit:
{diff}\n


Your answer:\n"""
)
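
To see exactly what the model receives, the template can also be rendered standalone; a small sketch with placeholder values (not part of this PR):

```python
from llm.prompts.classify_commit import zero_shot

# PromptTemplate.format fills the three placeholders and returns plain text.
rendered = zero_shot.format(
    repository_name="example/repo",
    commit_message="Sanitize user-supplied file paths",
    diff="- open(path)\n+ open(sanitize(path))",
)
print(rendered)
```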
prospector/llm/prompts.py → prospector/llm/prompts/get_repository_url.py: file renamed without changes.
16 changes: 15 additions & 1 deletion prospector/rules/rules.py
@@ -413,6 +413,18 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
        return False


class CommitIsSecurityRelevant(Rule):
    """Matches commits that are deemed security relevant by the commit classification service."""

    def apply(
        self,
        candidate: Commit,
    ) -> bool:
        return LLMService().classify_commit(
            candidate.diff, candidate.repository, candidate.message
        )


RULES_PHASE_1: List[Rule] = [
    VulnIdInMessage("VULN_ID_IN_MESSAGE", 64),
    # CommitMentionedInAdv("COMMIT_IN_ADVISORY", 64),
@@ -433,4 +445,6 @@ def apply(self, candidate: Commit, advisory_record: AdvisoryRecord):
    CommitHasTwins("COMMIT_HAS_TWINS", 2),
]

RULES_PHASE_2: List[Rule] = []
RULES_PHASE_2: List[Rule] = [
    CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)
]
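
As a sketch, the new phase-2 rule can be exercised against a single candidate like this (assuming `candidate` is a datamodel Commit exposing the diff, repository, and message attributes that `apply()` reads):

```python
# Illustrative only: run the rule directly rather than through apply_rules.
rule = CommitIsSecurityRelevant("COMMIT_IS_SECURITY_RELEVANT", 32)
if rule.apply(candidate):
    print(f"{candidate.repository}: commit flagged as security relevant")
```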
24 changes: 18 additions & 6 deletions prospector/rules/rules_test.py
@@ -89,7 +89,9 @@ def candidates():
            changed_files={
                "core/src/main/java/org/apache/cxf/workqueue/AutomaticWorkQueueImpl.java"
            },
            minhash=get_encoded_minhash(get_msg("Insecure deserialization", 50)),
            minhash=get_encoded_minhash(
                get_msg("Insecure deserialization", 50)
            ),
        ),
        # TODO: Not matched by existing tests: GHSecurityAdvInMessage, ReferencesBug, ChangesRelevantCode, TwinMentionedInAdv, VulnIdInLinkedIssue, SecurityKeywordInLinkedGhIssue, SecurityKeywordInLinkedBug, CrossReferencedBug, CrossReferencedGh, CommitHasTwins, ChangesRelevantFiles, CommitMentionedInAdv, RelevantWordsInMessage
    ]
@@ -109,37 +111,47 @@ def advisory_record():
    )


def test_apply_phase_1_rules(candidates: List[Commit], advisory_record: AdvisoryRecord):
def test_apply_phase_1_rules(
    candidates: List[Commit], advisory_record: AdvisoryRecord
):
    annotated_candidates = apply_rules(
        candidates, advisory_record, enabled_rules=enabled_rules_from_config
    )

    # Repo 5: Should match: AdvKeywordsInFiles, SecurityKeywordsInMsg, CommitMentionedInReference
    assert len(annotated_candidates[0].matched_rules) == 3

    matched_rules_names = [item["id"] for item in annotated_candidates[0].matched_rules]
    matched_rules_names = [
        item["id"] for item in annotated_candidates[0].matched_rules
    ]
    assert "ADV_KEYWORDS_IN_FILES" in matched_rules_names
    assert "COMMIT_IN_REFERENCE" in matched_rules_names
    assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names

    # Repo 1: Should match: VulnIdInMessage, ReferencesGhIssue
    assert len(annotated_candidates[1].matched_rules) == 2

    matched_rules_names = [item["id"] for item in annotated_candidates[1].matched_rules]
    matched_rules_names = [
        item["id"] for item in annotated_candidates[1].matched_rules
    ]
    assert "VULN_ID_IN_MESSAGE" in matched_rules_names
    assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names

    # Repo 3: Should match: VulnIdInMessage, ReferencesGhIssue
    assert len(annotated_candidates[2].matched_rules) == 2

    matched_rules_names = [item["id"] for item in annotated_candidates[2].matched_rules]
    matched_rules_names = [
        item["id"] for item in annotated_candidates[2].matched_rules
    ]
    assert "VULN_ID_IN_MESSAGE" in matched_rules_names
    assert "GITHUB_ISSUE_IN_MESSAGE" in matched_rules_names

    # Repo 4: Should match: SecurityKeywordsInMsg
    assert len(annotated_candidates[3].matched_rules) == 1

    matched_rules_names = [item["id"] for item in annotated_candidates[3].matched_rules]
    matched_rules_names = [
        item["id"] for item in annotated_candidates[3].matched_rules
    ]
    assert "SEC_KEYWORDS_IN_MESSAGE" in matched_rules_names

    # Repo 2: Matches nothing