-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: semantic search for large repos vector store toolkit #23
Changes from all commits
5870e33
e16e4ee
1439926
1f6742e
3c7a5f2
2b95bd5
cf86d5c
d6ddec8
8285b90
63de6b3
0635c5f
71cb303
2418b99
eff8ee1
f5f93b9
bd2d2b6
46443d5
2374bdd
8b98472
0dfd570
3da04fe
12f87be
b198eb9
9bcb9a9
9a488cd
016d940
ea5992a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,11 +10,20 @@ dependencies = [ | |
"rich>=13.7.1", | ||
"ruamel-yaml>=0.18.6", | ||
"click>=8.1.7", | ||
"prompt-toolkit>=3.0.47", | ||
"prompt-toolkit>=3.0.47" | ||
] | ||
author = [{ name = "Block", email = "[email protected]" }] | ||
packages = [{ include = "goose", from = "src" }] | ||
|
||
|
||
[project.optional-dependencies] | ||
vector = [ | ||
"sentence-transformers>=3.0.1", | ||
"torch>=2.4.0", | ||
"faiss-cpu" | ||
] | ||
|
||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/goose"] | ||
|
||
|
@@ -27,6 +36,7 @@ github = "goose.toolkit.github:Github" | |
jira = "goose.toolkit.jira:Jira" | ||
screen = "goose.toolkit.screen:Screen" | ||
repo_context = "goose.toolkit.repo_context.repo_context:RepoContext" | ||
vector = "goose.toolkit.vector:VectorToolkit" | ||
|
||
[project.entry-points."goose.profile"] | ||
default = "goose.profile:default_profile" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
When navigating a codebase, it can be useful to search for conceptually related items when needed. | ||
|
||
To do this you can use the find_similar_files_locations tool. It will provide potentially related paths and files to look at.
This should not be relied on alone, but used in tandem with other tools.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import hashlib
import os
from pathlib import Path
from typing import Optional

import faiss
import torch
from sentence_transformers import SentenceTransformer

from exchange import Message
from goose.toolkit.base import Toolkit, tool
|
||
|
||
GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() | ||
VECTOR_PATH = GOOSE_GLOBAL_PATH.joinpath("vectors") | ||
|
||
|
||
class VectorToolkit(Toolkit):
    """Use embeddings for finding related concepts in codebase.

    Builds (and caches under VECTOR_PATH) a database of file-content
    embeddings for a repository, then answers semantic queries against it
    with a FAISS L2 index.
    """

    # File extensions considered "code" and indexed; everything else is skipped.
    # A frozenset gives O(1) membership tests inside the os.walk loop.
    CODE_EXTENSIONS = frozenset({
        ".py", ".java", ".js", ".jsx", ".ts", ".tsx", ".cpp", ".c", ".h",
        ".hpp", ".rb", ".go", ".rs", ".php", ".md", ".dart", ".kt", ".swift",
        ".scala", ".lua", ".pl", ".r", ".m", ".mm", ".f", ".jl", ".cs",
        ".vb", ".pas", ".groovy", ".hs", ".elm", ".erl", ".clj", ".lisp",
    })

    # Lazily created SentenceTransformer; loading the model is slow, so it is
    # deferred until the first query actually needs it.
    _model = None

    @property
    def model(self) -> SentenceTransformer:
        """Load the embedding model on first access and cache it."""
        if self._model is None:
            # Silence huggingface tokenizers fork warning before model load.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            self.notifier.status("Preparing local model...")
            self._model = SentenceTransformer(
                "multi-qa-MiniLM-L6-cos-v1",
                tokenizer_kwargs={"clean_up_tokenization_spaces": True},
            )
        return self._model

    def get_db_path(self, repo_path: str) -> Path:
        """Return the cache-file path for a repository.

        The repo path is hashed (md5, used only as a filename key, not for
        security) so each distinct location gets its own database file.
        """
        repo_hash = hashlib.md5(repo_path.encode()).hexdigest()
        return VECTOR_PATH.joinpath(f"code_vectors_{repo_hash}.pt")

    def create_vector_db(self, repo_path: str) -> Path:
        """
        Create a vector database of the code in the specified directory and store it on disk.

        Args:
            repo_path (str): Path to the source code repository.

        Returns:
            Path: Path to the created vector database file.
        """
        temp_db_path = self.get_db_path(repo_path)
        VECTOR_PATH.mkdir(parents=True, exist_ok=True)
        self.notifier.status("Scanning repository (first time may take a while, please wait...)")
        file_paths, file_contents = self.scan_repository(repo_path)
        # Fixed typo in user-facing status message ("emebddings").
        self.notifier.status("Building local embeddings of code (first time may take a while, please wait...)")
        embeddings = self.build_vector_database(file_contents)
        self.save_vector_database(file_paths, embeddings, temp_db_path)
        return temp_db_path

    @tool
    def find_similar_files_locations(self, repo_path: str, query: str) -> str:
        """
        Locate files and locations in a repository that are conceptually related to the query
        and may hint where to look.
        Don't rely on this for exhaustive matches around strings, use ripgrep additionally for searching.

        Args:
            repo_path (str): The repository that we will be searching in
            query (str): Query string to search for semantically related files or paths.
        Returns:
            str: List of semantically relevant files and paths to consider.
        """
        temp_db_path = self.lookup_db_path(repo_path)
        if temp_db_path is None:
            # No database for this repo or any ancestor yet: build one lazily.
            temp_db_path = self.create_vector_db(repo_path)
        self.notifier.status("Loading embeddings database...")
        file_paths, embeddings = self.load_vector_database(temp_db_path)
        self.notifier.status("Performing query...")
        similar_files = self.find_similar_files(query, file_paths, embeddings)
        return "\n".join(similar_files)

    def lookup_db_path(self, repo_path: str) -> Optional[Path]:
        """
        Check if a vector database exists for the given repository path or its parent directories.

        Args:
            repo_path (str): Path to the source code repository.

        Returns:
            Optional[Path]: Path to the existing vector database file, or None if none found.
        """
        current_path = Path(repo_path).expanduser()
        # Walk up until the filesystem root (where parent == self).
        while current_path != current_path.parent:
            temp_db_path = self.get_db_path(str(current_path))
            if os.path.exists(temp_db_path):
                return temp_db_path
            current_path = current_path.parent
        return None

    def scan_repository(self, repo_path: Path) -> tuple[list[str], list[str]]:
        """Collect paths and contents of recognized source files under repo_path.

        Dot-directories are pruned; unreadable files are reported and skipped.

        Returns:
            tuple[list[str], list[str]]: Parallel lists of file paths and file contents.
        """
        repo_path = Path(repo_path).expanduser()
        file_contents = []
        file_paths = []
        for root, dirs, files in os.walk(repo_path):
            # Exclude dotfile directories (e.g. .git, .venv) from the walk.
            dirs[:] = [d for d in dirs if not d.startswith(".")]
            for file in files:
                file_extension = os.path.splitext(file)[1]
                if file_extension not in self.CODE_EXTENSIONS:
                    continue
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, "r", errors="ignore") as f:
                        content = f.read()
                except Exception as e:
                    # Best-effort: report and move on rather than aborting the scan.
                    print(f"Error reading {file_path}: {e}")
                    continue
                file_paths.append(file_path)
                file_contents.append(content)
        return file_paths, file_contents

    def build_vector_database(self, file_contents: list[str]) -> torch.Tensor:
        """Encode file contents into a (num_files, dim) embedding tensor."""
        embeddings = self.model.encode(file_contents, convert_to_tensor=True)
        return embeddings

    def save_vector_database(self, file_paths: list[str], embeddings: torch.Tensor, db_path: Path) -> None:
        """Persist paths and embeddings together with torch.save."""
        torch.save({"file_paths": file_paths, "embeddings": embeddings}, db_path)

    def load_vector_database(self, db_path: Path) -> tuple[list[str], torch.Tensor]:
        """Load a previously saved database.

        Raises:
            ValueError: if db_path is None or does not exist.
        """
        if db_path is None or not os.path.exists(db_path):
            raise ValueError(f"Database path {db_path} does not exist.")
        # weights_only=True avoids arbitrary-code unpickling of untrusted files.
        data = torch.load(db_path, weights_only=True)
        return data["file_paths"], data["embeddings"]

    def find_similar_files(
        self, query: str, file_paths: list[str], embeddings: torch.Tensor
    ) -> list[str]:
        """Return up to 10 files (plus up to 3 ancestor dirs each) nearest the query.

        Always returns a list so the caller's "\\n".join() behaves correctly —
        previously the empty case returned a bare str, which join() would have
        split into one character per line.
        """
        if embeddings.size(0) == 0:
            return ["No embeddings available to query against"]
        query_embedding = self.model.encode([query], convert_to_tensor=True).cpu().numpy()
        embeddings_np = embeddings.cpu().numpy()
        index = faiss.IndexFlatL2(embeddings_np.shape[1])
        index.add(embeddings_np)
        _, i = index.search(query_embedding, min(10, len(embeddings_np)))
        similar_files = [file_paths[idx] for idx in i[0]]
        # Also surface up to 3 levels of parent directories as location hints.
        expanded_similar_files = set()
        for file in similar_files:
            expanded_similar_files.add(file)
            parent = Path(file).parent
            depth = 0
            while parent != parent.parent and depth < 3:
                expanded_similar_files.add(str(parent))
                parent = parent.parent
                depth += 1
        return list(expanded_similar_files)

    def system(self) -> str:
        """Retrieve guidelines for semantic search"""
        return Message.load("prompts/vector.jinja").text
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from pathlib import Path | ||
import os | ||
from tempfile import TemporaryDirectory | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
from goose.toolkit.vector import VectorToolkit | ||
|
||
|
||
@pytest.fixture
def temp_dir():
    """Yield a throwaway directory as a Path, removed after the test."""
    with TemporaryDirectory() as tmp:
        yield Path(tmp)
|
||
|
||
@pytest.fixture
def vector_toolkit():
    """Provide a VectorToolkit wired to a mocked notifier (no real UI)."""
    return VectorToolkit(notifier=MagicMock())
|
||
|
||
def test_query_vector_db_creates_db(temp_dir, vector_toolkit):
    """Querying a repo with no existing database should lazily create one."""
    query = 'print("Hello World")'
    result = vector_toolkit.find_similar_files_locations(temp_dir.as_posix(), query)
    print("Query Result:", result)
    assert isinstance(result, str)
    db_path = vector_toolkit.get_db_path(temp_dir.as_posix())
    assert os.path.exists(db_path)
    assert os.path.getsize(db_path) > 0
|
||
|
||
def test_query_vector_db(temp_dir, vector_toolkit):
    """Querying against a pre-built database returns a string result."""
    vector_toolkit.create_vector_db(temp_dir.as_posix())
    query = 'print("Hello World")'
    result = vector_toolkit.find_similar_files_locations(temp_dir.as_posix(), query)
    print("Query Result:", result)
    assert isinstance(result, str)
    db_path = vector_toolkit.get_db_path(temp_dir.as_posix())
    assert os.path.exists(db_path)
    assert os.path.getsize(db_path) > 0
    # Empty repo yields the sentinel message; otherwise a newline-joined list.
    assert "No embeddings available to query against" in result or "\n" in result
|
||
|
||
def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit):
    """A database built for a parent directory is reused for its children."""
    higher_dir = temp_dir / "higher"
    higher_dir.mkdir()
    db_path_higher = vector_toolkit.create_vector_db(higher_dir.as_posix())

    lower_dir = higher_dir / "lower"
    lower_dir.mkdir()

    # Query from the nested directory; lookup should find the parent's DB.
    result = vector_toolkit.find_similar_files_locations(lower_dir.as_posix(), 'print("Hello World")')
    print("Query Result from Lower Directory:", result)

    # No database should have been created at the lower level.
    assert not os.path.exists(vector_toolkit.get_db_path(lower_dir.as_posix()))
    assert os.path.exists(db_path_higher)
    assert os.path.getsize(db_path_higher) > 0
|
||
|
||
def test_find_similar_files_in_repo(temp_dir, vector_toolkit):
    """Files semantically related to the query should show up in results."""
    tree = {
        "file1.py": "def function_one(): pass\n",
        "file2.py": "def function_two(): pass\n",
        "subdir": {"file3.py": "class MyClass: pass\n"},
    }

    def populate(base_path, spec):
        # Recursively materialize the nested dict as files and directories.
        for name, content in spec.items():
            node = base_path / name
            if isinstance(content, str):
                with open(node, "w") as f:
                    f.write(content)
            else:
                node.mkdir()
                populate(node, content)

    populate(temp_dir, tree)

    vector_toolkit.create_vector_db(temp_dir.as_posix())

    result = vector_toolkit.find_similar_files_locations(temp_dir.as_posix(), "def function_one")
    print("Similar Files Result:", result)
    assert "file1.py" in result or "subdir" in result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: `GOOSE_GLOBAL_PATH` can be imported from `config.py` instead of being redefined here.