-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from hrshdhgd/bug-finder
Added code for bug finding + fixing
- Loading branch information
Showing
12 changed files
with
395 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Bug-finder module for the package.""" | ||
|
||
from .bug_finder import BugFinder | ||
|
||
__all__ = ["BugFinder"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
"""Bug-finder class for the package.""" | ||
|
||
from pathlib import Path | ||
from typing import Any, Dict, Optional, Union | ||
|
||
from langchain_core.runnables.base import RunnableSerializable | ||
|
||
from codergpt.utils import extract_code_from_response | ||
|
||
|
||
class BugFinder: | ||
"""Bug-finder class for the package.""" | ||
|
||
def __init__(self, chain: RunnableSerializable[Dict, Any]): | ||
"""Initialize the BugFinder class.""" | ||
self.chain = chain | ||
|
||
def find_bugs( | ||
self, code: str, function: Optional[str] = None, classname: Optional[str] = None, language: Optional[str] = None | ||
): | ||
""" | ||
Find bugs in the given code. | ||
:param code: The code to find bugs in. | ||
:param function: The name of the function to find bugs in. Default is None. | ||
:param classname: The name of the class to find bugs in. Default is None. | ||
:param language: The language of the code. Default is None. | ||
""" | ||
if function: | ||
response = self.chain.invoke( | ||
{ | ||
"input": f"Find and list all the bugs in the function {function}" | ||
f" in the following {language} code: \n\n```\n{code}\n```" | ||
} | ||
) | ||
# Pretty print the response | ||
print(f"Bugs found in '{function}':\n{response.content}") | ||
elif classname: | ||
response = self.chain.invoke( | ||
{ | ||
"input": f"Find and list all the bugs in the class {classname}" | ||
f" in the following {language} code: \n\n```\n{code}\n```" | ||
} | ||
) | ||
# Pretty print the response | ||
print(f"Bugs found in '{classname}':\n{response.content}") | ||
else: | ||
# Find bugs in full code | ||
response = self.chain.invoke( | ||
{"input": f"Find and list all the bugs in the following {language} code: \n\n```\n{code}\n```"} | ||
) | ||
# Pretty print the response | ||
print(f"Bugs found in the code:\n{response.content}") | ||
|
||
def fix_bugs( | ||
self, | ||
filename: Union[str, Path], | ||
code: str, | ||
function: Optional[str] = None, | ||
classname: Optional[str] = None, | ||
language: Optional[str] = None, | ||
outfile: Optional[str] = None, | ||
) -> None: | ||
""" | ||
Fix bugs in the given code. | ||
:param code: The code to fix bugs in. | ||
:param function: The name of the function to fix bugs in. Default is None. | ||
:param classname: The name of the class to fix bugs | ||
:param outfile:Path for output file with bug-fix code. Default is None. | ||
""" | ||
if function: | ||
response = self.chain.invoke( | ||
{ | ||
"input": f"List all the bug fixes if any and rewrite the function {function}" | ||
f" in the following {language} code: \n\n```\n{code}\n```" | ||
} | ||
) | ||
# Pretty print the response | ||
print(f"Fixed code for '{function}':\n{response.content}") | ||
return response.content | ||
elif classname: | ||
response = self.chain.invoke( | ||
{ | ||
"input": f"List all the bug fixes if any and rewrite the class {classname}" | ||
f" in the following {language} code: \n\n```\n{code}\n```" | ||
} | ||
) | ||
# Pretty print the response | ||
print(f"Fixed code for '{classname}':\n{response.content}") | ||
return response.content | ||
else: | ||
# Fix bugs in full code | ||
response = self.chain.invoke( | ||
{ | ||
"input": f"List all the bug fixes if any and rewrite the following {language}" | ||
f" code: \n\n```\n{code}\n```" | ||
} | ||
) | ||
return extract_code_from_response(language, response.content, filename, outfile) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
"""Utility functions for the codergpt package.""" | ||
|
||
import os | ||
import re | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
import yaml | ||
|
||
from codergpt.constants import EXTENSION_MAP_FILE | ||
|
||
|
||
def extract_code_from_response( | ||
language: str, response: str, filename: Union[str, Path], outfile: Optional[str] = None | ||
) -> str: | ||
""" | ||
Generate code files based on LLM responses. | ||
:param language: Code language. | ||
:param response: LLM response. | ||
:param filename: Source code file. | ||
:param outfile: Destination filepath, defaults to None | ||
""" | ||
base, ext = os.path.splitext(filename) | ||
file_parent = Path(filename).parent | ||
|
||
if not language: | ||
get_language_from_extension(filename) | ||
|
||
code_pattern_block = rf"```{language.lower()}(.*?)(?<=\n)```" | ||
matches = re.findall(code_pattern_block, response, re.DOTALL) | ||
|
||
if matches: | ||
code_to_save = matches[0].strip() | ||
if not outfile: | ||
outfile = f"{file_parent/base}_updated{ext}" | ||
with open(outfile, "w") as file: | ||
file.write(code_to_save) | ||
print(f"Fixed code saved in file: {outfile}") | ||
|
||
print(response) | ||
return response | ||
|
||
|
||
def get_language_from_extension(filename: Union[str, Path]) -> Optional[str]: | ||
""" | ||
Get the language of a file from its extension. | ||
:param filename: The filename to get the language for. | ||
:return: The language of the file, if found. | ||
""" | ||
with open(EXTENSION_MAP_FILE, "r") as file: | ||
extension_to_language = yaml.safe_load(file) | ||
language = extension_to_language["language-map"].get(Path(filename).suffix) | ||
return language |
Oops, something went wrong.