Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: semantic search for large repos vector store toolkit #23

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5870e33
tests passing
michaelneale Aug 28, 2024
e16e4ee
add plugin
michaelneale Aug 28, 2024
1439926
status for vectors
michaelneale Aug 28, 2024
1f6742e
progress
michaelneale Aug 28, 2024
3c7a5f2
working but messy
michaelneale Aug 28, 2024
2b95bd5
tidier
michaelneale Aug 28, 2024
cf86d5c
set model
michaelneale Aug 28, 2024
d6ddec8
add more language extension types
michaelneale Aug 28, 2024
8285b90
set cleanup tokenization spaces to true
michaelneale Aug 28, 2024
63de6b3
small fixes for tensorflow future compatibility
michaelneale Aug 28, 2024
0635c5f
now is working
michaelneale Aug 29, 2024
71cb303
shifting to faster model
michaelneale Aug 29, 2024
2418b99
trimming list of files
michaelneale Aug 29, 2024
eff8ee1
only create vector db once per file hierarchy
michaelneale Aug 30, 2024
f5f93b9
switching models for more semantic similarity
michaelneale Aug 30, 2024
bd2d2b6
restoring weights only
michaelneale Sep 2, 2024
46443d5
make it smarter with embeddings and paths
michaelneale Sep 3, 2024
2374bdd
use faiss and better matching
michaelneale Sep 12, 2024
8b98472
Merge remote-tracking branch 'origin/main' into vector_store
michaelneale Sep 12, 2024
0dfd570
use faiss and better prompt
michaelneale Sep 12, 2024
3da04fe
fixing up types
michaelneale Sep 12, 2024
12f87be
more fixing
michaelneale Sep 12, 2024
b198eb9
Merge branch 'main' into vector_store
michaelneale Sep 16, 2024
9bcb9a9
making vector search stuff optional
michaelneale Sep 16, 2024
9a488cd
need to include extras in ci
michaelneale Sep 16, 2024
016d940
Merge remote-tracking branch 'origin/main' into vector_store
michaelneale Oct 2, 2024
ea5992a
updating to main
michaelneale Oct 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ dependencies = [
"ai-exchange>=0.8.0",
"click>=8.1.7",
"prompt-toolkit>=3.0.47",
"sentence-transformers>=3.0.1",
michaelneale marked this conversation as resolved.
Show resolved Hide resolved
"torch>=2.4.0",
]
author = [{ name = "Block", email = "[email protected]" }]
packages = [{ include = "goose", from = "src" }]
Expand All @@ -23,6 +25,7 @@ developer = "goose.toolkit.developer:Developer"
github = "goose.toolkit.github:Github"
screen = "goose.toolkit.screen:Screen"
repo_context = "goose.toolkit.repo_context.repo_context:RepoContext"
vector = "goose.toolkit.vector:VectorToolkit"

[project.entry-points."goose.profile"]
default = "goose.profile:default_profile"
Expand All @@ -45,4 +48,3 @@ dev-dependencies = [
"pytest>=8.3.2",
"codecov>=2.1.13",
]

130 changes: 130 additions & 0 deletions src/goose/toolkit/vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os
import tempfile
import torch
import hashlib
from goose.toolkit.base import Toolkit, tool
from sentence_transformers import SentenceTransformer, util
from goose.cli.session import SessionNotifier
from pathlib import Path


# Root of goose's global configuration directory.
# NOTE(review): a GOOSE_GLOBAL_PATH constant may already exist in the
# project's config module; consider importing it instead of redefining it here.
GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser()

# Directory under which per-repository vector databases are stored.
VECTOR_PATH = GOOSE_GLOBAL_PATH / "vectors"


class VectorToolkit(Toolkit):
    """Semantic code search over a repository using sentence embeddings.

    Lazily builds (once per repository tree) a small vector database of
    source-file embeddings under VECTOR_PATH, then answers natural-language
    queries with the most semantically similar file paths.
    """

    # NOTE(review): evaluated at class-definition time, so the model is loaded
    # (and possibly downloaded) as soon as this module is imported, even if
    # the toolkit is never used. Consider lazy initialization.
    model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', tokenizer_kwargs={"clean_up_tokenization_spaces": True})

    # File extensions whose contents are indexed by scan_repository.
    _INDEXED_EXTENSIONS = frozenset({
        '.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', '.h',
        '.hpp', '.rb', '.go', '.rs', '.php', '.md', '.dart', '.kt', '.swift',
        '.scala', '.lua', '.pl', '.r', '.m', '.mm', '.f', '.jl', '.cs', '.vb',
        '.pas', '.groovy', '.hs', '.elm', '.erl', '.clj', '.lisp',
    })

    def get_db_path(self, repo_path: str) -> Path:
        """Return the vector-database file for *repo_path*.

        The file name embeds an MD5 hash of the repository path so the same
        repository always maps to the same database file. MD5 is used purely
        as a stable fingerprint here, not for security.
        """
        repo_hash = hashlib.md5(repo_path.encode()).hexdigest()
        return VECTOR_PATH.joinpath(f'code_vectors_{repo_hash}.pt')

    def create_vector_db(self, repo_path: str) -> Path:
        """
        Create a vector database of the code in the specified directory and store it under VECTOR_PATH.

        Args:
            repo_path (str): Path to the source code repository.

        Returns:
            Path: Path to the created vector database file.
        """
        temp_db_path = self.get_db_path(repo_path)
        VECTOR_PATH.mkdir(parents=True, exist_ok=True)
        self.notifier.status("Preparing vector database :: Scanning repository (first time may take a while, please wait...)")
        file_paths, file_contents = self.scan_repository(repo_path)
        self.notifier.status("Preparing vector database :: Building vectors (first time may take a while, please wait...)")
        embeddings = self.build_vector_database(file_contents)
        self.notifier.status("Saving vector database...")
        self.save_vector_database(file_paths, embeddings, temp_db_path)
        self.notifier.status("Completed vector database creation")
        return temp_db_path

    @tool
    def query_vector_db(self, repo_path: str, query: str) -> str:
        """
        Locate files in a repository that are potentially semantically related to the query and may hint where to look.

        Args:
            repo_path (str): The repository that we will be searching in
            query (str): Query string to search for semantically related files or paths.
        Returns:
            str: List of semantically relevant files to look in, also consider the paths the files are in.
        """
        temp_db_path = self.lookup_db_path(repo_path)
        if temp_db_path is None:
            # No database exists for this tree yet; build one lazily.
            temp_db_path = self.create_vector_db(repo_path)
        self.notifier.status("Loading vector database...")
        file_paths, embeddings = self.load_vector_database(temp_db_path)
        self.notifier.status("Performing query...")
        similar_files = self.find_similar_files(query, file_paths, embeddings)
        if isinstance(similar_files, str):
            # find_similar_files returns a message string when there is
            # nothing to search; '\n'.join() over that string would splice a
            # newline between every character of the message.
            return similar_files
        return '\n'.join(similar_files)

    def lookup_db_path(self, repo_path: str) -> "Path | None":
        """
        Check if a vector database exists for the given repository path or its parent directories.

        Args:
            repo_path (str): Path to the source code repository.

        Returns:
            Path | None: Path to the existing vector database file, or None if none found.
        """
        current_path = Path(repo_path).expanduser()
        # Walk up the directory tree; stop at the filesystem root, where a
        # path equals its own parent.
        while current_path != current_path.parent:
            temp_db_path = self.get_db_path(str(current_path))
            if temp_db_path.exists():
                return temp_db_path
            current_path = current_path.parent
        return None

    def scan_repository(self, repo_path: str) -> "tuple[list[str], list[str]]":
        """Collect indexable files under *repo_path*.

        Returns two parallel lists: the file paths and their text contents.
        Hidden directories (e.g. .git, .venv) are skipped, as is any file
        whose extension is not in _INDEXED_EXTENSIONS.
        """
        root_path = Path(repo_path).expanduser()
        file_paths = []
        file_contents = []
        for root, dirs, files in os.walk(root_path):
            # Prune dot-directories in place so os.walk never descends into them.
            dirs[:] = [d for d in dirs if not d.startswith('.')]
            for file in files:
                if os.path.splitext(file)[1] not in self._INDEXED_EXTENSIONS:
                    continue
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, 'r', errors='ignore') as f:
                        content = f.read()
                except Exception as e:
                    # Only record a file once its contents were read: the
                    # original appended the path before reading, so a failed
                    # read left file_paths and file_contents misaligned and
                    # every later embedding mapped to the wrong path.
                    print(f'Error reading {file_path}: {e}')
                    continue
                file_paths.append(file_path)
                file_contents.append(content)
        return file_paths, file_contents

    def build_vector_database(self, file_contents):
        """Encode the file contents into one (num_files, dim) tensor."""
        return self.model.encode(file_contents, convert_to_tensor=True)

    def save_vector_database(self, file_paths, embeddings, db_path):
        """Persist the parallel file paths and embeddings to *db_path*."""
        torch.save({'file_paths': file_paths, 'embeddings': embeddings}, db_path)

    def load_vector_database(self, db_path):
        """Load (file_paths, embeddings) from *db_path*.

        Raises:
            ValueError: If db_path is None or does not exist.
        """
        if db_path is None or not os.path.exists(db_path):
            raise ValueError(f"Database path {db_path} does not exist.")
        # weights_only=True restricts unpickling to tensors and plain
        # containers — sufficient for the dict written by
        # save_vector_database, safe against malicious pickles, and matches
        # the default of newer torch releases.
        data = torch.load(db_path, weights_only=True)
        return data['file_paths'], data['embeddings']

    def find_similar_files(self, query, file_paths, embeddings):
        """Return up to 10 file paths ranked by cosine similarity to *query*,
        or a plain message string when there are no embeddings to search.
        """
        if embeddings.size(0) == 0:
            return 'No embeddings available to query against'
        query_embedding = self.model.encode([query], convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
        # torch.topk raises when k exceeds the number of entries, so clamp k
        # for repositories with fewer than 10 indexed files.
        k = min(10, embeddings.size(0))
        _, top_indices = torch.topk(scores, k=k)
        return [file_paths[idx] for idx in top_indices]

    def system(self) -> str:
        """Prompt snippet advertising the query_vector_db tool to the model."""
        return "**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**"
61 changes: 61 additions & 0 deletions tests/toolkit/test_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pathlib import Path
import os
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock

import pytest
from goose.toolkit.vector import VectorToolkit

@pytest.fixture
def temp_dir():
    # Yield a pathlib.Path to a scratch directory that is removed after the test.
    with TemporaryDirectory() as tmp:
        yield Path(tmp)

@pytest.fixture
def vector_toolkit():
    # Toolkit wired to a mock notifier so status() calls are harmless no-ops.
    toolkit = VectorToolkit(notifier=MagicMock())
    return toolkit

def test_query_vector_db_creates_db(temp_dir, vector_toolkit):
    # Querying a repository with no existing index should build one lazily.
    query = 'print("Hello World")'
    result = vector_toolkit.query_vector_db(temp_dir.as_posix(), query)
    print("Query Result:", result)
    assert isinstance(result, str)
    # The database file must now exist on disk and be non-empty.
    db_path = vector_toolkit.get_db_path(temp_dir.as_posix())
    assert os.path.exists(db_path)
    assert os.path.getsize(db_path) > 0


def test_query_vector_db(temp_dir, vector_toolkit):
    # Build the database up front, then query it.
    vector_toolkit.create_vector_db(temp_dir.as_posix())
    query = 'print("Hello World")'
    result = vector_toolkit.query_vector_db(temp_dir.as_posix(), query)
    print("Query Result:", result)
    assert isinstance(result, str)
    db_path = vector_toolkit.get_db_path(temp_dir.as_posix())
    assert os.path.exists(db_path)
    assert os.path.getsize(db_path) > 0
    # An empty repository yields the no-embeddings message; otherwise a
    # newline-separated list of matching paths.
    assert 'No embeddings available to query against' in result or '\n' in result
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose in the future, we could make an integration test with ollama for this one, or possibly an in-memory embeddings lib?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah - something scaled down and deterministic ideally



def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit):
    # Build a database for an ancestor directory first.
    higher_dir = temp_dir / "higher"
    higher_dir.mkdir()
    db_path_higher = vector_toolkit.create_vector_db(higher_dir.as_posix())

    # Then query from a directory nested beneath it.
    lower_dir = higher_dir / "lower"
    lower_dir.mkdir()
    query = 'print("Hello World")'
    result = vector_toolkit.query_vector_db(lower_dir.as_posix(), query)
    print("Query Result from Lower Directory:", result)

    # The ancestor's database must be reused: no database is created for the
    # nested directory, and the ancestor's remains present and non-empty.
    db_path_lower = vector_toolkit.get_db_path(lower_dir.as_posix())
    assert not os.path.exists(db_path_lower)
    assert os.path.exists(db_path_higher)
    assert os.path.getsize(db_path_higher) > 0