From 5870e33d6db6bf6bcaf06191410c1f891d26d29b Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 15:54:40 +1000 Subject: [PATCH 01/23] tests passing --- pyproject.toml | 2 + src/goose/toolkit/vector.py | 94 ++++++++++++++++++++++++++++++++++++ tests/toolkit/test_vector.py | 42 ++++++++++++++++ 3 files changed, 138 insertions(+) create mode 100644 src/goose/toolkit/vector.py create mode 100644 tests/toolkit/test_vector.py diff --git a/pyproject.toml b/pyproject.toml index 18bff707..eb0c78e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,8 @@ dependencies = [ "ai-exchange>=0.8.0", "click>=8.1.7", "prompt-toolkit>=3.0.47", + "sentence-transformers>=3.0.1", + "torch>=2.4.0", ] author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] packages = [{ include = "goose", from = "src" }] diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py new file mode 100644 index 00000000..1eddac8b --- /dev/null +++ b/src/goose/toolkit/vector.py @@ -0,0 +1,94 @@ +import os +import tempfile +import torch +import uuid +import hashlib +from goose.toolkit.base import Toolkit, tool +from sentence_transformers import SentenceTransformer, util + +class VectorToolkit(Toolkit): + def __init__(self, notifier): + super().__init__(notifier) + self.model = SentenceTransformer('all-MiniLM-L6-v2') + + def get_db_path(self, repo_path): + # Create a hash of the repo path + repo_hash = hashlib.md5(repo_path.encode()).hexdigest() + return os.path.join(tempfile.gettempdir(), f'code_vectors_{repo_hash}.pt') + + @tool + def create_vector_db(self, repo_path: str) -> str: + """ + Create a vector database of the code in the specified directory and store it in a temp file. + + Args: + repo_path (str): Path to the source code repository. + + Returns: + str: Path to the created vector database file. + """ + temp_db_path = self.get_db_path(repo_path) + file_paths, file_contents = self.scan_repository(repo_path) + print("Scanned File Paths:", file_paths) + embeddings = self.build_vector_database(file_contents) + self.save_vector_database(file_paths, embeddings, temp_db_path) + return f'Vector database created at {temp_db_path}' + + @tool + def query_vector_db(self, repo_path: str, query: str) -> str: + """ + Query the vector database with the provided string and return similar files. + + Args: + query (str): Query string to search in the database. + + Returns: + str: List of similar files found in the vector database. + """ + temp_db_path = self.get_db_path(repo_path) + file_paths, embeddings = self.load_vector_database(temp_db_path) + print("File Paths:", file_paths) + print("Embeddings Size:", embeddings.size()) + similar_files = self.find_similar_files(query, file_paths, embeddings) + return '\n'.join(similar_files) + + def scan_repository(self, repo_path): + print(f'Scanning repository at: {repo_path}') + file_contents = [] + file_paths = [] + for root, dirs, files in os.walk(repo_path): + for file in files: + if file.endswith(('.py', '.java', '.js', '.cpp', '.c', '.h', '.rb', '.go', '.rs', '.php', '.html', '.css', '.md', '.dart')): + file_path = os.path.join(root, file) + file_paths.append(file_path) + try: + with open(file_path, 'r', errors='ignore') as f: + content = f.read() + file_contents.append(content) + except Exception as e: + print(f'Error reading {file_path}: {e}') + return file_paths, file_contents + + def build_vector_database(self, file_contents): + embeddings = self.model.encode(file_contents, convert_to_tensor=True) + return embeddings + + def save_vector_database(self, file_paths, embeddings, db_path): + torch.save({'file_paths': file_paths, 'embeddings': embeddings}, db_path) + + def load_vector_database(self, db_path): + data = torch.load(db_path) + return data['file_paths'], data['embeddings'] + + def find_similar_files(self, query, file_paths, embeddings): + query_embedding = self.model.encode([query], convert_to_tensor=True) + if embeddings.size(0) == 0: + return 'No embeddings available to query against' + scores = util.pytorch_cos_sim(query_embedding, embeddings)[0] + top_results = torch.topk(scores, k=10) + similar_files = [file_paths[idx] for idx in top_results[1]] + return similar_files + + def system(self) -> str: + return """**When the user wants to create a vector database or query an existing one, use the create_vector_db and query_vector_db tools respectively.**""" + diff --git a/tests/toolkit/test_vector.py b/tests/toolkit/test_vector.py new file mode 100644 index 00000000..8b34089f --- /dev/null +++ b/tests/toolkit/test_vector.py @@ -0,0 +1,42 @@ +from pathlib import Path +import os +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock + +import pytest +from goose.toolkit.vector import VectorToolkit + +@pytest.fixture +def temp_dir(): + with TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + +@pytest.fixture +def vector_toolkit(): + return VectorToolkit(notifier=MagicMock()) + +def test_create_vector_db(temp_dir, vector_toolkit): + # Create some test files + (temp_dir / 'test1.py').write_text('print("Hello World")') + (temp_dir / 'test2.py').write_text('def foo():\n return "bar"') + + print(f"Created test files in: {temp_dir}") + for path in temp_dir.glob('*'): + print(f"- {path.name}") + + temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) + + result = vector_toolkit.create_vector_db(temp_dir.as_posix()) + assert 'Vector database created at' in result + assert os.path.exists(temp_db_path) + assert os.path.getsize(temp_db_path) > 0 + +def test_query_vector_db(temp_dir, vector_toolkit): + # Create and load a vector database + vector_toolkit.create_vector_db(temp_dir.as_posix()) + query = 'print("Hello World")' + result = vector_toolkit.query_vector_db(temp_dir.as_posix(), query) + print("Query Result:", result) + assert isinstance(result, str) + # Ensure no exception and the result is handled gracefully + assert 'No embeddings available to query against' in result or '\n' in result From e16e4ee15e10d4779814f703d6169f464627dd47 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 17:30:50 +1000 Subject: [PATCH 02/23] add plugin --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index eb0c78e5..c6e7bb1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ developer = "goose.toolkit.developer:Developer" github = "goose.toolkit.github:Github" screen = "goose.toolkit.screen:Screen" repo_context = "goose.toolkit.repo_context.repo_context:RepoContext" +vector = "goose.toolkit.vector:VectorToolkit" [project.entry-points."goose.profile"] default = "goose.profile:default_profile" @@ -47,4 +48,3 @@ dev-dependencies = [ "pytest>=8.3.2", "codecov>=2.1.13", ] - From 1439926f05f813860ce16ee89e3579646fe8ea06 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 17:34:17 +1000 Subject: [PATCH 03/23] status for vectors --- src/goose/toolkit/vector.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 1eddac8b..a6257a2c 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -5,6 +5,7 @@ import hashlib from goose.toolkit.base import Toolkit, tool from sentence_transformers import SentenceTransformer, util +from goose.cli.session import SessionNotifier class VectorToolkit(Toolkit): def __init__(self, notifier): @@ -28,10 +29,14 @@ def create_vector_db(self, repo_path: str) -> str: str: Path to the created vector database file. """ temp_db_path = self.get_db_path(repo_path) + self.notifier.status("Scanning repository...") file_paths, file_contents = self.scan_repository(repo_path) + self.notifier.status("Building vector database...") print("Scanned File Paths:", file_paths) embeddings = self.build_vector_database(file_contents) + self.notifier.status("Saving vector database...") self.save_vector_database(file_paths, embeddings, temp_db_path) + self.notifier.status("Completed vector database creation") return f'Vector database created at {temp_db_path}' @tool @@ -46,10 +51,13 @@ def query_vector_db(self, repo_path: str, query: str) -> str: str: List of similar files found in the vector database. """ temp_db_path = self.get_db_path(repo_path) + self.notifier.status("Loading vector database...") file_paths, embeddings = self.load_vector_database(temp_db_path) print("File Paths:", file_paths) print("Embeddings Size:", embeddings.size()) + self.notifier.status("Performing query...") similar_files = self.find_similar_files(query, file_paths, embeddings) + self.notifier.status("Query completed") return '\n'.join(similar_files) def scan_repository(self, repo_path): From 1f6742ea36e872051938b8d69638c284603154f1 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 17:53:14 +1000 Subject: [PATCH 04/23] progress --- src/goose/toolkit/vector.py | 16 +++++++++------- tests/toolkit/test_vector.py | 26 +++++++++++--------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index a6257a2c..f593dd06 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -17,7 +17,6 @@ def get_db_path(self, repo_path): repo_hash = hashlib.md5(repo_path.encode()).hexdigest() return os.path.join(tempfile.gettempdir(), f'code_vectors_{repo_hash}.pt') - @tool def create_vector_db(self, repo_path: str) -> str: """ Create a vector database of the code in the specified directory and store it in a temp file. @@ -37,20 +36,23 @@ def create_vector_db(self, repo_path: str) -> str: self.notifier.status("Saving vector database...") self.save_vector_database(file_paths, embeddings, temp_db_path) self.notifier.status("Completed vector database creation") - return f'Vector database created at {temp_db_path}' + return temp_db_path @tool def query_vector_db(self, repo_path: str, query: str) -> str: """ - Query the vector database with the provided string and return similar files. + Locate files in a repository that are potentially semantically related to the query and may hint where to look. Args: - query (str): Query string to search in the database. + query (str): Query string to search for semantically related files or paths. Returns: - str: List of similar files found in the vector database. + str: List of semantically relevant files to look in, also consider the paths the files are in. """ temp_db_path = self.get_db_path(repo_path) + if not os.path.exists(temp_db_path): + self.notifier.status("Vector database not found. Creating vector database...") + self.create_vector_db(repo_path) self.notifier.status("Loading vector database...") file_paths, embeddings = self.load_vector_database(temp_db_path) print("File Paths:", file_paths) @@ -66,7 +68,7 @@ def scan_repository(self, repo_path): file_paths = [] for root, dirs, files in os.walk(repo_path): for file in files: - if file.endswith(('.py', '.java', '.js', '.cpp', '.c', '.h', '.rb', '.go', '.rs', '.php', '.html', '.css', '.md', '.dart')): + if file.endswith(('.py', '.java', '.js', '.cpp', '.c', '.h', '.rb', '.go', '.rs', '.php', '.css', '.md', '.dart')): file_path = os.path.join(root, file) file_paths.append(file_path) try: @@ -98,5 +100,5 @@ def find_similar_files(self, query, file_paths, embeddings): return similar_files def system(self) -> str: - return """**When the user wants to create a vector database or query an existing one, use the create_vector_db and query_vector_db tools respectively.**""" + return """**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**""" diff --git a/tests/toolkit/test_vector.py b/tests/toolkit/test_vector.py index 8b34089f..461fd416 100644 --- a/tests/toolkit/test_vector.py +++ b/tests/toolkit/test_vector.py @@ -15,28 +15,24 @@ def temp_dir(): def vector_toolkit(): return VectorToolkit(notifier=MagicMock()) -def test_create_vector_db(temp_dir, vector_toolkit): - # Create some test files - (temp_dir / 'test1.py').write_text('print("Hello World")') - (temp_dir / 'test2.py').write_text('def foo():\n return "bar"') - - print(f"Created test files in: {temp_dir}") - for path in temp_dir.glob('*'): - print(f"- {path.name}") - +def test_query_vector_db_creates_db(temp_dir, vector_toolkit): + # Create and load a vector database lazily + query = 'print("Hello World")' + result = vector_toolkit.query_vector_db(temp_dir.as_posix(), query) + print("Query Result:", result) + assert isinstance(result, str) temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) - - result = vector_toolkit.create_vector_db(temp_dir.as_posix()) - assert 'Vector database created at' in result assert os.path.exists(temp_db_path) assert os.path.getsize(temp_db_path) > 0 def test_query_vector_db(temp_dir, vector_toolkit): - # Create and load a vector database + # Create initial db vector_toolkit.create_vector_db(temp_dir.as_posix()) query = 'print("Hello World")' result = vector_toolkit.query_vector_db(temp_dir.as_posix(), query) print("Query Result:", result) assert isinstance(result, str) - # Ensure no exception and the result is handled gracefully - assert 'No embeddings available to query against' in result or '\n' in result + temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) + assert os.path.exists(temp_db_path) + assert os.path.getsize(temp_db_path) > 0 + assert 'No embeddings available to query against' in result or '\n' in result \ No newline at end of file From 3c7a5f20071d6413f6d58abc56bfb99251717810 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 18:04:30 +1000 Subject: [PATCH 05/23] working but messy --- src/goose/toolkit/vector.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index f593dd06..a9616474 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -1,16 +1,13 @@ import os import tempfile import torch -import uuid import hashlib from goose.toolkit.base import Toolkit, tool from sentence_transformers import SentenceTransformer, util from goose.cli.session import SessionNotifier class VectorToolkit(Toolkit): - def __init__(self, notifier): - super().__init__(notifier) - self.model = SentenceTransformer('all-MiniLM-L6-v2') + def get_db_path(self, repo_path): # Create a hash of the repo path @@ -44,8 +41,8 @@ def query_vector_db(self, repo_path: str, query: str) -> str: Locate files in a repository that are potentially semantically related to the query and may hint where to look. Args: - query (str): Query string to search for semantically related files or paths. - + repo_path (str): The repository that we will be searching in + query (str): Query string to search for semantically related files or paths. Returns: str: List of semantically relevant files to look in, also consider the paths the files are in. """ @@ -100,5 +97,6 @@ def find_similar_files(self, query, file_paths, embeddings): return similar_files def system(self) -> str: + self.model = SentenceTransformer('all-MiniLM-L6-v2') return """**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**""" From 2b95bd528dfd9674a51c506ad31ef785e40a56b6 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 18:06:19 +1000 Subject: [PATCH 06/23] tidier --- src/goose/toolkit/vector.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index a9616474..81a5ba23 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -28,7 +28,6 @@ def create_vector_db(self, repo_path: str) -> str: self.notifier.status("Scanning repository...") file_paths, file_contents = self.scan_repository(repo_path) self.notifier.status("Building vector database...") - print("Scanned File Paths:", file_paths) embeddings = self.build_vector_database(file_contents) self.notifier.status("Saving vector database...") self.save_vector_database(file_paths, embeddings, temp_db_path) @@ -52,15 +51,12 @@ def query_vector_db(self, repo_path: str, query: str) -> str: self.create_vector_db(repo_path) self.notifier.status("Loading vector database...") file_paths, embeddings = self.load_vector_database(temp_db_path) - print("File Paths:", file_paths) - print("Embeddings Size:", embeddings.size()) self.notifier.status("Performing query...") similar_files = self.find_similar_files(query, file_paths, embeddings) self.notifier.status("Query completed") return '\n'.join(similar_files) def scan_repository(self, repo_path): - print(f'Scanning repository at: {repo_path}') file_contents = [] file_paths = [] for root, dirs, files in os.walk(repo_path): From cf86d5c552d8b406e0971b8c790d264cda14c899 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 18:37:39 +1000 Subject: [PATCH 07/23] set model --- src/goose/toolkit/vector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 81a5ba23..019e4c9e 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -9,6 +9,8 @@ class VectorToolkit(Toolkit): + model = SentenceTransformer('all-MiniLM-L6-v2') + def get_db_path(self, repo_path): # Create a hash of the repo path repo_hash = hashlib.md5(repo_path.encode()).hexdigest() @@ -93,6 +95,6 @@ def find_similar_files(self, query, file_paths, embeddings): return similar_files def system(self) -> str: - self.model = SentenceTransformer('all-MiniLM-L6-v2') + return """**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**""" From d6ddec8a9f865c8a20c9d1802722a88a37a89b67 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 18:48:52 +1000 Subject: [PATCH 08/23] add more language extension types --- src/goose/toolkit/vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 019e4c9e..ac4d53d4 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -63,7 +63,7 @@ def scan_repository(self, repo_path): file_paths = [] for root, dirs, files in os.walk(repo_path): for file in files: - if file.endswith(('.py', '.java', '.js', '.cpp', '.c', '.h', '.rb', '.go', '.rs', '.php', '.css', '.md', '.dart')): + if file.endswith(('.py', '.java', '.js', '.cpp', '.c', '.h', '.rb', '.go', '.rs', '.php', '.css', '.md', '.dart', '.kt', '.ts', '.yaml', '.yml')): file_path = os.path.join(root, file) file_paths.append(file_path) try: From 8285b90eb6d4dd64ab9ceaf387737d7aae6ec338 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 18:52:25 +1000 Subject: [PATCH 09/23] set cleanup tokenization spaces to true --- src/goose/toolkit/vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index ac4d53d4..0cde6188 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -9,7 +9,7 @@ class VectorToolkit(Toolkit): - model = SentenceTransformer('all-MiniLM-L6-v2') + model = SentenceTransformer('all-MiniLM-L6-v2', tokenizer_kwargs={"clean_up_tokenization_spaces": True} ) def get_db_path(self, repo_path): # Create a hash of the repo path From 63de6b38951faa8d34549dff4751b49f5be99438 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 28 Aug 2024 19:02:02 +1000 Subject: [PATCH 10/23] small fixes for tensorflow future compatibility --- src/goose/toolkit/vector.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 0cde6188..181f6470 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -5,16 +5,21 @@ from goose.toolkit.base import Toolkit, tool from sentence_transformers import SentenceTransformer, util from goose.cli.session import SessionNotifier +from pathlib import Path + + +GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() +VECTOR_PATH = GOOSE_GLOBAL_PATH.joinpath("vectors") + class VectorToolkit(Toolkit): - - model = SentenceTransformer('all-MiniLM-L6-v2', tokenizer_kwargs={"clean_up_tokenization_spaces": True} ) + model = SentenceTransformer('all-MiniLM-L6-v2', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) def get_db_path(self, repo_path): # Create a hash of the repo path repo_hash = hashlib.md5(repo_path.encode()).hexdigest() - return os.path.join(tempfile.gettempdir(), f'code_vectors_{repo_hash}.pt') + return VECTOR_PATH.joinpath(f'code_vectors_{repo_hash}.pt') def create_vector_db(self, repo_path: str) -> str: """ @@ -27,6 +32,7 @@ def create_vector_db(self, repo_path: str) -> str: str: Path to the created vector database file. """ temp_db_path = self.get_db_path(repo_path) + VECTOR_PATH.mkdir(parents=True, exist_ok=True) self.notifier.status("Scanning repository...") file_paths, file_contents = self.scan_repository(repo_path) self.notifier.status("Building vector database...") @@ -82,7 +88,7 @@ def save_vector_database(self, file_paths, embeddings, db_path): torch.save({'file_paths': file_paths, 'embeddings': embeddings}, db_path) def load_vector_database(self, db_path): - data = torch.load(db_path) + data = torch.load(db_path, weights_only=True) return data['file_paths'], data['embeddings'] def find_similar_files(self, query, file_paths, embeddings): @@ -96,5 +102,6 @@ def find_similar_files(self, query, file_paths, embeddings): def system(self) -> str: - return """**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**""" + return "**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**" + From 0635c5f4d98586ebba5c84a296083b656c587997 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Thu, 29 Aug 2024 19:46:48 +1000 Subject: [PATCH 11/23] now is working --- src/goose/toolkit/vector.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 181f6470..8af60856 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -33,9 +33,9 @@ def create_vector_db(self, repo_path: str) -> str: """ temp_db_path = self.get_db_path(repo_path) VECTOR_PATH.mkdir(parents=True, exist_ok=True) - self.notifier.status("Scanning repository...") + self.notifier.status("Preparing vector database :: Scanning repository (first time may take a while, please wait...)") file_paths, file_contents = self.scan_repository(repo_path) - self.notifier.status("Building vector database...") + self.notifier.status("Preparing vector database :: Building vectors (first time may take a while, please wait...)") embeddings = self.build_vector_database(file_contents) self.notifier.status("Saving vector database...") self.save_vector_database(file_paths, embeddings, temp_db_path) @@ -55,21 +55,24 @@ def query_vector_db(self, repo_path: str, query: str) -> str: """ temp_db_path = self.get_db_path(repo_path) if not os.path.exists(temp_db_path): - self.notifier.status("Vector database not found. Creating vector database...") self.create_vector_db(repo_path) self.notifier.status("Loading vector database...") file_paths, embeddings = self.load_vector_database(temp_db_path) self.notifier.status("Performing query...") - similar_files = self.find_similar_files(query, file_paths, embeddings) - self.notifier.status("Query completed") + similar_files = self.find_similar_files(query, file_paths, embeddings) return '\n'.join(similar_files) def scan_repository(self, repo_path): + repo_path = Path(repo_path).expanduser() file_contents = [] file_paths = [] + skipped_file_types = {} for root, dirs, files in os.walk(repo_path): + # Exclude dotfile directories + dirs[:] = [d for d in dirs if not d.startswith('.')] for file in files: - if file.endswith(('.py', '.java', '.js', '.cpp', '.c', '.h', '.rb', '.go', '.rs', '.php', '.css', '.md', '.dart', '.kt', '.ts', '.yaml', '.yml')): + file_extension = os.path.splitext(file)[1] + if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.css', '.scss', '.less', '.md', '.dart', '.kt', '.swift', '.scala', '.sql', '.sh', '.bash', '.yaml', '.yml', '.json', '.xml', '.html', '.vue', '.lua', '.pl', '.r', '.m', '.mm', '.f', '.f90', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', '.erl', '.ex', '.clj', '.lisp', '.ml', '.nim']: file_path = os.path.join(root, file) file_paths.append(file_path) try: @@ -78,8 +81,10 @@ def scan_repository(self, repo_path): file_contents.append(content) except Exception as e: print(f'Error reading {file_path}: {e}') + else: + skipped_file_types[file_extension] = True return file_paths, file_contents - + def build_vector_database(self, file_contents): embeddings = self.model.encode(file_contents, convert_to_tensor=True) return embeddings From 71cb30373fa47664070e258baefea82aa5c0e861 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Thu, 29 Aug 2024 21:01:23 +1000 Subject: [PATCH 12/23] shifting to faster model --- src/goose/toolkit/vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 8af60856..ca5f7f44 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -14,7 +14,7 @@ class VectorToolkit(Toolkit): - model = SentenceTransformer('all-MiniLM-L6-v2', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) + model = SentenceTransformer('paraphrase-albert-small-v2', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) def get_db_path(self, repo_path): # Create a hash of the repo path From 2418b99b95fb9617211bba55971608b6db600c8e Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Fri, 30 Aug 2024 09:47:38 +1000 Subject: [PATCH 13/23] trimming list of files --- src/goose/toolkit/vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index ca5f7f44..dddb1bbd 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -72,7 +72,7 @@ def scan_repository(self, repo_path): dirs[:] = [d for d in dirs if not d.startswith('.')] for file in files: file_extension = os.path.splitext(file)[1] - if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.css', '.scss', '.less', '.md', '.dart', '.kt', '.swift', '.scala', '.sql', '.sh', '.bash', '.yaml', '.yml', '.json', '.xml', '.html', '.vue', '.lua', '.pl', '.r', '.m', '.mm', '.f', '.f90', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', '.erl', '.ex', '.clj', '.lisp', '.ml', '.nim']: + if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.css', '.scss', '.less', '.md', '.dart', '.kt', '.swift', '.scala', '.html', '.vue', '.lua', '.pl', '.r', '.m', '.mm', '.f', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', '.erl', '.clj', '.lisp']: file_path = os.path.join(root, file) file_paths.append(file_path) try: From eff8ee1fd7eda6a290b553ab73ee9c4410979a34 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Fri, 30 Aug 2024 17:09:32 +1000 Subject: [PATCH 14/23] only create vector db once her file heirarchy --- src/goose/toolkit/vector.py | 34 ++++++++++++++++++++++++++-------- tests/toolkit/test_vector.py | 25 ++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index dddb1bbd..419d7676 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -53,15 +53,33 @@ def query_vector_db(self, repo_path: str, query: str) -> str: Returns: str: List of semantically relevant files to look in, also consider the paths the files are in. """ - temp_db_path = self.get_db_path(repo_path) - if not os.path.exists(temp_db_path): - self.create_vector_db(repo_path) + temp_db_path = self.lookup_db_path(repo_path) + if temp_db_path is None: + temp_db_path = self.create_vector_db(repo_path) self.notifier.status("Loading vector database...") file_paths, embeddings = self.load_vector_database(temp_db_path) self.notifier.status("Performing query...") similar_files = self.find_similar_files(query, file_paths, embeddings) return '\n'.join(similar_files) + def lookup_db_path(self, repo_path: str) -> str: + """ + Check if a vector database exists for the given repository path or its parent directories. + + Args: + repo_path (str): Path to the source code repository. + + Returns: + str: Path to the existing vector database file, or None if none found. + """ + current_path = Path(repo_path).expanduser() + while current_path != current_path.parent: + temp_db_path = self.get_db_path(str(current_path)) + if os.path.exists(temp_db_path): + return temp_db_path + current_path = current_path.parent + return None + def scan_repository(self, repo_path): repo_path = Path(repo_path).expanduser() file_contents = [] @@ -72,7 +90,7 @@ def scan_repository(self, repo_path): dirs[:] = [d for d in dirs if not d.startswith('.')] for file in files: file_extension = os.path.splitext(file)[1] - if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.css', '.scss', '.less', '.md', '.dart', '.kt', '.swift', '.scala', '.html', '.vue', '.lua', '.pl', '.r', '.m', '.mm', '.f', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', '.erl', '.clj', '.lisp']: + if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.md', '.dart', '.kt', '.swift', '.scala', '.lua', '.pl', '.r', '.m', '.mm', '.f', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', '.erl', '.clj', '.lisp']: file_path = os.path.join(root, file) file_paths.append(file_path) try: @@ -93,7 +111,10 @@ def save_vector_database(self, file_paths, embeddings, db_path): torch.save({'file_paths': file_paths, 'embeddings': embeddings}, db_path) def load_vector_database(self, db_path): - data = torch.load(db_path, weights_only=True) + if db_path is not None and os.path.exists(db_path): + data = torch.load(db_path) + else: + raise ValueError(f"Database path {db_path} does not exist.") return data['file_paths'], data['embeddings'] def find_similar_files(self, query, file_paths, embeddings): @@ -106,7 +127,4 @@ def find_similar_files(self, query, file_paths, embeddings): return similar_files def system(self) -> str: - return "**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**" - - diff --git a/tests/toolkit/test_vector.py b/tests/toolkit/test_vector.py index 461fd416..61119c0d 100644 --- a/tests/toolkit/test_vector.py +++ b/tests/toolkit/test_vector.py @@ -25,6 +25,7 @@ def test_query_vector_db_creates_db(temp_dir, vector_toolkit): assert os.path.exists(temp_db_path) assert os.path.getsize(temp_db_path) > 0 + def test_query_vector_db(temp_dir, vector_toolkit): # Create initial db vector_toolkit.create_vector_db(temp_dir.as_posix()) @@ -35,4 +36,26 @@ def test_query_vector_db(temp_dir, vector_toolkit): temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) assert os.path.exists(temp_db_path) assert os.path.getsize(temp_db_path) > 0 - assert 'No embeddings available to query against' in result or '\n' in result \ No newline at end of file + assert 'No embeddings available to query against' in result or '\n' in result + + +def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit): + # Create a vector DB at a higher level + higher_dir = temp_dir / "higher" + higher_dir.mkdir() + db_path_higher = vector_toolkit.create_vector_db(higher_dir.as_posix()) + + # Now create a lower directory + lower_dir = higher_dir / "lower" + lower_dir.mkdir() + + # Perform query on the lower directory + query = 'print("Hello World")' + result = vector_toolkit.query_vector_db(lower_dir.as_posix(), query) + print("Query Result from Lower Directory:", result) + + # Ensure a DB at the lower level is not created + temp_db_path_lower = vector_toolkit.get_db_path(lower_dir.as_posix()) + assert not os.path.exists(temp_db_path_lower) + assert os.path.exists(db_path_higher) + assert os.path.getsize(db_path_higher) > 0 From f5f93b9ac58fedbb73d5f1fde2d7a857706f6c4c Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Fri, 30 Aug 2024 19:09:47 +1000 Subject: [PATCH 15/23] switching models for more semantic similarity --- src/goose/toolkit/vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 419d7676..eaf3bfdc 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -14,7 +14,7 @@ class VectorToolkit(Toolkit): - model = SentenceTransformer('paraphrase-albert-small-v2', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) + model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) def get_db_path(self, repo_path): # Create a hash of the repo path From bd2d2b6f811d580d02435168f02da59cbf6f6d89 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Mon, 2 Sep 2024 15:39:26 +1000 Subject: [PATCH 16/23] restoring weights only --- src/goose/toolkit/vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index eaf3bfdc..af0251c2 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -112,7 +112,7 @@ def save_vector_database(self, file_paths, embeddings, db_path): def load_vector_database(self, db_path): if db_path is not None and os.path.exists(db_path): - data = torch.load(db_path) + data = torch.load(db_path, weights_only=True) else: raise ValueError(f"Database path {db_path} does not exist.") return data['file_paths'], data['embeddings'] From 46443d59115a2781b6d643c31d17156e018c79c8 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Tue, 3 Sep 2024 17:57:27 +1000 Subject: [PATCH 17/23] make it smarter with embeddings and paths --- src/goose/toolkit/prompts/vector.jinja | 9 +++++ src/goose/toolkit/vector.py | 46 ++++++++++++++++++-------- tests/toolkit/test_vector.py | 38 +++++++++++++++++++-- 3 files changed, 77 insertions(+), 16 deletions(-) create mode 100644 src/goose/toolkit/prompts/vector.jinja diff --git a/src/goose/toolkit/prompts/vector.jinja b/src/goose/toolkit/prompts/vector.jinja new file mode 100644 index 00000000..935e328e --- /dev/null +++ b/src/goose/toolkit/prompts/vector.jinja @@ -0,0 +1,9 @@ +When working in an existing code repository, this tool can help with finding similar or relevant files and paths. +This is useful when the user wants to search for things around concepts or similar items. + +To accomplish this, you can use the find_simmilar_files_locations tool which will return files and paths that may be relevant to inform the query. +These may also be values which can be further searched for relevance. Searching this way looks for things that +are similar in concept to what was asked (or close). + +If you need to search for what looks like an exact string, use another tool and approach. +If you need to search for where to start or a concept, consider the find_simmilar_files_locations tool. \ No newline at end of file diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index af0251c2..96dff2b1 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -1,11 +1,13 @@ import os -import tempfile import torch import hashlib from goose.toolkit.base import Toolkit, tool from sentence_transformers import SentenceTransformer, util from goose.cli.session import SessionNotifier from pathlib import Path +from exchange import Message +import os + GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() @@ -14,7 +16,15 @@ class VectorToolkit(Toolkit): - model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) + _model = None + + @property + def model(self): + if self._model is None: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self.notifier.status("Preparing local model...") + self._model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) + return self._model def get_db_path(self, repo_path): # Create a hash of the repo path @@ -33,30 +43,30 @@ def create_vector_db(self, repo_path: str) -> str: """ temp_db_path = self.get_db_path(repo_path) VECTOR_PATH.mkdir(parents=True, exist_ok=True) - self.notifier.status("Preparing vector database :: Scanning repository (first time may take a while, please wait...)") + self.notifier.status("Scanning repository (first time may take a while, please wait...)") file_paths, file_contents = self.scan_repository(repo_path) - self.notifier.status("Preparing vector database :: Building vectors (first time may take a while, please wait...)") + self.notifier.status("Building local emebddings of code (first time may take a while, please wait...)") embeddings = self.build_vector_database(file_contents) - self.notifier.status("Saving vector database...") self.save_vector_database(file_paths, embeddings, temp_db_path) - self.notifier.status("Completed vector database creation") return temp_db_path @tool - def query_vector_db(self, repo_path: str, query: str) -> str: + def find_simmilar_files_locations(self, repo_path: str, query: str) -> str: """ - Locate files in a repository that are potentially semantically related to the query and may hint where to look. + Locate files and locations in a repository that are potentially semantically related to the query and may hint where to look. + This will not be the only location, but can serve as a starting point. + Note that they will probably not be exact matches, so other tools for text searching can be used if needed. Args: repo_path (str): The repository that we will be searching in query (str): Query string to search for semantically related files or paths. Returns: - str: List of semantically relevant files to look in, also consider the paths the files are in. + str: List of semantically relevant files and paths to consider. """ temp_db_path = self.lookup_db_path(repo_path) if temp_db_path is None: temp_db_path = self.create_vector_db(repo_path) - self.notifier.status("Loading vector database...") + self.notifier.status("Loading embeddings database...") file_paths, embeddings = self.load_vector_database(temp_db_path) self.notifier.status("Performing query...") similar_files = self.find_similar_files(query, file_paths, embeddings) @@ -122,9 +132,19 @@ def find_similar_files(self, query, file_paths, embeddings): if embeddings.size(0) == 0: return 'No embeddings available to query against' scores = util.pytorch_cos_sim(query_embedding, embeddings)[0] - top_results = torch.topk(scores, k=10) + top_results = torch.topk(scores, k=min(10, scores.size(0))) similar_files = [file_paths[idx] for idx in top_results[1]] - return similar_files + expanded_similar_files = set() + for file in similar_files: + expanded_similar_files.add(file) + parent = Path(file).parent + depth = 0 + while parent != parent.parent and depth < 3: + expanded_similar_files.add(str(parent)) + parent = parent.parent + depth += 1 + return list(expanded_similar_files) def system(self) -> str: - return "**When looking at a large repository for relevant files or paths to examine related semantically to the question, use the query_vector_db tool**" + """Retrieve guidelines for semantic search""" + return Message.load("prompts/vector.jinja").text diff --git a/tests/toolkit/test_vector.py b/tests/toolkit/test_vector.py index 61119c0d..825c1aa8 100644 --- a/tests/toolkit/test_vector.py +++ b/tests/toolkit/test_vector.py @@ -18,7 +18,7 @@ def vector_toolkit(): def test_query_vector_db_creates_db(temp_dir, vector_toolkit): # Create and load a vector database lazily query = 'print("Hello World")' - result = vector_toolkit.query_vector_db(temp_dir.as_posix(), query) + result = vector_toolkit.find_simmilar_files_locations(temp_dir.as_posix(), query) print("Query Result:", result) assert isinstance(result, str) temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) @@ -30,7 +30,7 @@ def test_query_vector_db(temp_dir, vector_toolkit): # Create initial db vector_toolkit.create_vector_db(temp_dir.as_posix()) query = 'print("Hello World")' - result = vector_toolkit.query_vector_db(temp_dir.as_posix(), query) + result = vector_toolkit.find_simmilar_files_locations(temp_dir.as_posix(), query) print("Query Result:", result) assert isinstance(result, str) temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) @@ -51,7 +51,7 @@ def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit): # Perform query on the lower directory query = 'print("Hello World")' - result = vector_toolkit.query_vector_db(lower_dir.as_posix(), query) + result = vector_toolkit.find_simmilar_files_locations(lower_dir.as_posix(), query) print("Query Result from Lower Directory:", result) # Ensure a DB at the lower level is not created @@ -59,3 +59,35 @@ def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit): assert not os.path.exists(temp_db_path_lower) assert os.path.exists(db_path_higher) assert os.path.getsize(db_path_higher) > 0 + + +def test_find_similar_files_in_repo(temp_dir, vector_toolkit): + # Setting up a temporary repository structure + file_structure = { + 'file1.py': 'def function_one(): pass\n', + 'file2.py': 'def function_two(): pass\n', + 'subdir': { + 'file3.py': 'class MyClass: pass\n' + } + } + + def create_files(base_path, structure): + for name, content in structure.items(): + path = base_path / name + if isinstance(content, str): + with open(path, 'w') as f: + f.write(content) + else: + path.mkdir() + create_files(path, content) + + create_files(temp_dir, file_structure) + + # Create initial db + vector_toolkit.create_vector_db(temp_dir.as_posix()) + + # Test query + query = 'def function_one' + result = vector_toolkit.find_simmilar_files_locations(temp_dir.as_posix(), query) + print("Similar Files Result:", result) + assert 'file1.py' in result or 'subdir' in result From 2374bdd72eb20a96dc6f66c50e33077ea85843cd Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Thu, 12 Sep 2024 16:37:39 +1000 Subject: [PATCH 18/23] use faiss and better matching --- .goosehints | 3 ++ pyproject.toml | 1 + src/goose/toolkit/prompts/vector.jinja | 11 ++---- src/goose/toolkit/vector.py | 48 ++++++++++++++++++-------- tests/toolkit/test_vector.py | 8 ++--- 5 files changed, 45 insertions(+), 26 deletions(-) create mode 100644 .goosehints diff --git a/.goosehints b/.goosehints new file mode 100644 index 00000000..8b6535a6 --- /dev/null +++ b/.goosehints @@ -0,0 +1,3 @@ +This is a python CLI app that uses UV. Read CONTRIBUTING.md for information on how to build and test it as needed. +Some key concepts are that it is run as a command line interface, dependes on the "ai-exchange" package, and has the concept of toolkits which are ways that its behavior can be extended. Look in src/goose and tests. +Once the user has UV installed it should be able to be used effectively along with uvx to run tasks as needed diff --git a/pyproject.toml b/pyproject.toml index c6e7bb1e..32c02752 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "prompt-toolkit>=3.0.47", "sentence-transformers>=3.0.1", "torch>=2.4.0", + "faiss-cpu" ] author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] packages = [{ include = "goose", from = "src" }] diff --git a/src/goose/toolkit/prompts/vector.jinja b/src/goose/toolkit/prompts/vector.jinja index 935e328e..98b38274 100644 --- a/src/goose/toolkit/prompts/vector.jinja +++ b/src/goose/toolkit/prompts/vector.jinja @@ -1,9 +1,4 @@ -When working in an existing code repository, this tool can help with finding similar or relevant files and paths. -This is useful when the user wants to search for things around concepts or similar items. +When navigating a codebase, it can be useful to search for conceptually related items when needed. -To accomplish this, you can use the find_simmilar_files_locations tool which will return files and paths that may be relevant to inform the query. -These may also be values which can be further searched for relevance. Searching this way looks for things that -are similar in concept to what was asked (or close). - -If you need to search for what looks like an exact string, use another tool and approach. -If you need to search for where to start or a concept, consider the find_simmilar_files_locations tool. \ No newline at end of file +To do this you can use the find_similar_files_locations tool. It will provide potentially related paths and files to look at. +This should not be relied on alone, but in tandem with other tools \ No newline at end of file diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 96dff2b1..4d77c5d9 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -6,8 +6,7 @@ from goose.cli.session import SessionNotifier from pathlib import Path from exchange import Message -import os - +import faiss GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() @@ -15,9 +14,9 @@ class VectorToolkit(Toolkit): - + _model = None - + @property def model(self): if self._model is None: @@ -51,15 +50,34 @@ def create_vector_db(self, repo_path: str) -> str: return temp_db_path @tool - def find_simmilar_files_locations(self, repo_path: str, query: str) -> str: + def find_similar_files_locations(self, repo_path: str, query: str) -> str: + """ + Locate files and locations in a repository that are conceptually related to the query and may hint where to look. + Don't rely on this for exhaustive matches around strings, use ripgrep additionally for searching. + + Args: + repo_path (str): The repository that we will be searching in + query (str): Query string to search for semantically related files or paths. + Returns: + str: List of semantically relevant files and paths to consider. + """ + temp_db_path = self.lookup_db_path(repo_path) + if temp_db_path is None: + temp_db_path = self.create_vector_db(repo_path) + self.notifier.status("Loading embeddings database...") + file_paths, embeddings = self.load_vector_database(temp_db_path) + self.notifier.status("Performing query...") + similar_files = self.find_similar_files(query, file_paths, embeddings) + return '\n'.join(similar_files) """ Locate files and locations in a repository that are potentially semantically related to the query and may hint where to look. - This will not be the only location, but can serve as a starting point. - Note that they will probably not be exact matches, so other tools for text searching can be used if needed. + This will not be the only location, but can serve as a starting point. + Note that they will probably not be exact matches, so other tools for text searching can be used as well. + Don't rely on this for exhaustive matches around keywords, it is about concepts. Args: repo_path (str): The repository that we will be searching in - query (str): Query string to search for semantically related files or paths. + query (str): Query string to search for semantically related files or paths. Returns: str: List of semantically relevant files and paths to consider. """ @@ -69,7 +87,7 @@ def find_simmilar_files_locations(self, repo_path: str, query: str) -> str: self.notifier.status("Loading embeddings database...") file_paths, embeddings = self.load_vector_database(temp_db_path) self.notifier.status("Performing query...") - similar_files = self.find_similar_files(query, file_paths, embeddings) + similar_files = self.find_similar_files(query, file_paths, embeddings) return '\n'.join(similar_files) def lookup_db_path(self, repo_path: str) -> str: @@ -112,7 +130,7 @@ def scan_repository(self, repo_path): else: skipped_file_types[file_extension] = True return file_paths, file_contents - + def build_vector_database(self, file_contents): embeddings = self.model.encode(file_contents, convert_to_tensor=True) return embeddings @@ -128,12 +146,14 @@ def load_vector_database(self, db_path): return data['file_paths'], data['embeddings'] def find_similar_files(self, query, file_paths, embeddings): - query_embedding = self.model.encode([query], convert_to_tensor=True) if embeddings.size(0) == 0: return 'No embeddings available to query against' - scores = util.pytorch_cos_sim(query_embedding, embeddings)[0] - top_results = torch.topk(scores, k=min(10, scores.size(0))) - similar_files = [file_paths[idx] for idx in top_results[1]] + query_embedding = self.model.encode([query], convert_to_tensor=True).cpu().numpy() + embeddings_np = embeddings.cpu().numpy() + index = faiss.IndexFlatL2(embeddings_np.shape[1]) + index.add(embeddings_np) + D, I = index.search(query_embedding, min(10, len(embeddings_np))) + similar_files = [file_paths[idx] for idx in I[0]] expanded_similar_files = set() for file in similar_files: expanded_similar_files.add(file) diff --git a/tests/toolkit/test_vector.py b/tests/toolkit/test_vector.py index 825c1aa8..87e3bed5 100644 --- a/tests/toolkit/test_vector.py +++ b/tests/toolkit/test_vector.py @@ -18,7 +18,7 @@ def vector_toolkit(): def test_query_vector_db_creates_db(temp_dir, vector_toolkit): # Create and load a vector database lazily query = 'print("Hello World")' - result = vector_toolkit.find_simmilar_files_locations(temp_dir.as_posix(), query) + result = vector_toolkit.find_similar_files_locations(temp_dir.as_posix(), query) print("Query Result:", result) assert isinstance(result, str) temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) @@ -30,7 +30,7 @@ def test_query_vector_db(temp_dir, vector_toolkit): # Create initial db vector_toolkit.create_vector_db(temp_dir.as_posix()) query = 'print("Hello World")' - result = vector_toolkit.find_simmilar_files_locations(temp_dir.as_posix(), query) + result = vector_toolkit.find_similar_files_locations(temp_dir.as_posix(), query) print("Query Result:", result) assert isinstance(result, str) temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) @@ -51,7 +51,7 @@ def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit): # Perform query on the lower directory query = 'print("Hello World")' - result = vector_toolkit.find_simmilar_files_locations(lower_dir.as_posix(), query) + result = vector_toolkit.find_similar_files_locations(lower_dir.as_posix(), query) print("Query Result from Lower Directory:", result) # Ensure a DB at the lower level is not created @@ -88,6 +88,6 @@ def create_files(base_path, structure): # Test query query = 'def function_one' - result = vector_toolkit.find_simmilar_files_locations(temp_dir.as_posix(), query) + result = vector_toolkit.find_similar_files_locations(temp_dir.as_posix(), query) print("Similar Files Result:", result) assert 'file1.py' in result or 'subdir' in result From 0dfd5702e7ca543f986a5fdd234287dd9bf07314 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Thu, 12 Sep 2024 16:48:19 +1000 Subject: [PATCH 19/23] use faiss and better prompt --- src/goose/toolkit/vector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index 4d77c5d9..d5d4c26f 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -14,6 +14,7 @@ class VectorToolkit(Toolkit): + """Use embeddings for finding related concepts in codebase. """ _model = None From 3da04feedb59a814f374d02a1e100f0b248af5a1 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Thu, 12 Sep 2024 17:10:59 +1000 Subject: [PATCH 20/23] fixing up types --- src/goose/toolkit/vector.py | 52 ++++++++++++++----------------------- 1 file changed, 19 insertions(+), 33 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index d5d4c26f..d880a630 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -2,8 +2,7 @@ import torch import hashlib from goose.toolkit.base import Toolkit, tool -from sentence_transformers import SentenceTransformer, util -from goose.cli.session import SessionNotifier +from sentence_transformers import SentenceTransformer from pathlib import Path from exchange import Message import faiss @@ -19,14 +18,15 @@ class VectorToolkit(Toolkit): _model = None @property - def model(self): + def model(self) -> SentenceTransformer: if self._model is None: os.environ["TOKENIZERS_PARALLELISM"] = "false" self.notifier.status("Preparing local model...") - self._model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', tokenizer_kwargs={"clean_up_tokenization_spaces": True}) + self._model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', + tokenizer_kwargs={"clean_up_tokenization_spaces": True}) return self._model - def get_db_path(self, repo_path): + def get_db_path(self, repo_path:str) -> Path: # Create a hash of the repo path repo_hash = hashlib.md5(repo_path.encode()).hexdigest() return VECTOR_PATH.joinpath(f'code_vectors_{repo_hash}.pt') @@ -53,7 +53,8 @@ def create_vector_db(self, repo_path: str) -> str: @tool def find_similar_files_locations(self, repo_path: str, query: str) -> str: """ - Locate files and locations in a repository that are conceptually related to the query and may hint where to look. + Locate files and locations in a repository that are conceptually related to the query + and may hint where to look. Don't rely on this for exhaustive matches around strings, use ripgrep additionally for searching. Args: @@ -70,26 +71,7 @@ def find_similar_files_locations(self, repo_path: str, query: str) -> str: self.notifier.status("Performing query...") similar_files = self.find_similar_files(query, file_paths, embeddings) return '\n'.join(similar_files) - """ - Locate files and locations in a repository that are potentially semantically related to the query and may hint where to look. - This will not be the only location, but can serve as a starting point. - Note that they will probably not be exact matches, so other tools for text searching can be used as well. - Don't rely on this for exhaustive matches around keywords, it is about concepts. - Args: - repo_path (str): The repository that we will be searching in - query (str): Query string to search for semantically related files or paths. - Returns: - str: List of semantically relevant files and paths to consider. - """ - temp_db_path = self.lookup_db_path(repo_path) - if temp_db_path is None: - temp_db_path = self.create_vector_db(repo_path) - self.notifier.status("Loading embeddings database...") - file_paths, embeddings = self.load_vector_database(temp_db_path) - self.notifier.status("Performing query...") - similar_files = self.find_similar_files(query, file_paths, embeddings) - return '\n'.join(similar_files) def lookup_db_path(self, repo_path: str) -> str: """ @@ -109,7 +91,7 @@ def lookup_db_path(self, repo_path: str) -> str: current_path = current_path.parent return None - def scan_repository(self, repo_path): + def scan_repository(self, repo_path:Path) -> tuple[list[str], list[str]]: repo_path = Path(repo_path).expanduser() file_contents = [] file_paths = [] @@ -119,7 +101,11 @@ def scan_repository(self, repo_path): dirs[:] = [d for d in dirs if not d.startswith('.')] for file in files: file_extension = os.path.splitext(file)[1] - if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.md', '.dart', '.kt', '.swift', '.scala', '.lua', '.pl', '.r', '.m', '.mm', '.f', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', '.erl', '.clj', '.lisp']: + if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', + '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.md', '.dart', + '.kt', '.swift', '.scala', '.lua', '.pl', '.r', '.m', '.mm', + '.f', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', + '.erl', '.clj', '.lisp']: file_path = os.path.join(root, file) file_paths.append(file_path) try: @@ -132,29 +118,29 @@ def scan_repository(self, repo_path): skipped_file_types[file_extension] = True return file_paths, file_contents - def build_vector_database(self, file_contents): + def build_vector_database(self, file_contents:str) -> list[any]: embeddings = self.model.encode(file_contents, convert_to_tensor=True) return embeddings - def save_vector_database(self, file_paths, embeddings, db_path): + def save_vector_database(self, file_paths:list[str], embeddings:list[any], db_path:Path) -> None: torch.save({'file_paths': file_paths, 'embeddings': embeddings}, db_path) - def load_vector_database(self, db_path): + def load_vector_database(self, db_path:Path) -> tuple[list[str], list[any]]: if db_path is not None and os.path.exists(db_path): data = torch.load(db_path, weights_only=True) else: raise ValueError(f"Database path {db_path} does not exist.") return data['file_paths'], data['embeddings'] - def find_similar_files(self, query, file_paths, embeddings): + def find_similar_files(self, query:str, file_paths:list[Path], embeddings:tuple[list[str], list[any]]) -> list[str]: if embeddings.size(0) == 0: return 'No embeddings available to query against' query_embedding = self.model.encode([query], convert_to_tensor=True).cpu().numpy() embeddings_np = embeddings.cpu().numpy() index = faiss.IndexFlatL2(embeddings_np.shape[1]) index.add(embeddings_np) - D, I = index.search(query_embedding, min(10, len(embeddings_np))) - similar_files = [file_paths[idx] for idx in I[0]] + _, i = index.search(query_embedding, min(10, len(embeddings_np))) + similar_files = [file_paths[idx] for idx in i[0]] expanded_similar_files = set() for file in similar_files: expanded_similar_files.add(file) From 12f87be1c4a8c3d9f5d369019831da3701cd3830 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Thu, 12 Sep 2024 17:15:35 +1000 Subject: [PATCH 21/23] more fixing --- src/goose/toolkit/vector.py | 80 +++++++++++++++++++++++++----------- tests/toolkit/test_vector.py | 19 +++++---- 2 files changed, 67 insertions(+), 32 deletions(-) diff --git a/src/goose/toolkit/vector.py b/src/goose/toolkit/vector.py index d880a630..caec29d0 100644 --- a/src/goose/toolkit/vector.py +++ b/src/goose/toolkit/vector.py @@ -13,7 +13,7 @@ class VectorToolkit(Toolkit): - """Use embeddings for finding related concepts in codebase. """ + """Use embeddings for finding related concepts in codebase.""" _model = None @@ -22,14 +22,15 @@ def model(self) -> SentenceTransformer: if self._model is None: os.environ["TOKENIZERS_PARALLELISM"] = "false" self.notifier.status("Preparing local model...") - self._model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', - tokenizer_kwargs={"clean_up_tokenization_spaces": True}) + self._model = SentenceTransformer( + "multi-qa-MiniLM-L6-cos-v1", tokenizer_kwargs={"clean_up_tokenization_spaces": True} + ) return self._model - def get_db_path(self, repo_path:str) -> Path: + def get_db_path(self, repo_path: str) -> Path: # Create a hash of the repo path repo_hash = hashlib.md5(repo_path.encode()).hexdigest() - return VECTOR_PATH.joinpath(f'code_vectors_{repo_hash}.pt') + return VECTOR_PATH.joinpath(f"code_vectors_{repo_hash}.pt") def create_vector_db(self, repo_path: str) -> str: """ @@ -70,8 +71,7 @@ def find_similar_files_locations(self, repo_path: str, query: str) -> str: file_paths, embeddings = self.load_vector_database(temp_db_path) self.notifier.status("Performing query...") similar_files = self.find_similar_files(query, file_paths, embeddings) - return '\n'.join(similar_files) - + return "\n".join(similar_files) def lookup_db_path(self, repo_path: str) -> str: """ @@ -91,50 +91,84 @@ def lookup_db_path(self, repo_path: str) -> str: current_path = current_path.parent return None - def scan_repository(self, repo_path:Path) -> tuple[list[str], list[str]]: + def scan_repository(self, repo_path: Path) -> tuple[list[str], list[str]]: repo_path = Path(repo_path).expanduser() file_contents = [] file_paths = [] skipped_file_types = {} for root, dirs, files in os.walk(repo_path): # Exclude dotfile directories - dirs[:] = [d for d in dirs if not d.startswith('.')] + dirs[:] = [d for d in dirs if not d.startswith(".")] for file in files: file_extension = os.path.splitext(file)[1] - if file_extension in ['.py', '.java', '.js', '.jsx', '.ts', '.tsx', '.cpp', '.c', - '.h', '.hpp', '.rb', '.go', '.rs', '.php', '.md', '.dart', - '.kt', '.swift', '.scala', '.lua', '.pl', '.r', '.m', '.mm', - '.f', '.jl', '.cs', '.vb', '.pas', '.groovy', '.hs', '.elm', - '.erl', '.clj', '.lisp']: + if file_extension in [ + ".py", + ".java", + ".js", + ".jsx", + ".ts", + ".tsx", + ".cpp", + ".c", + ".h", + ".hpp", + ".rb", + ".go", + ".rs", + ".php", + ".md", + ".dart", + ".kt", + ".swift", + ".scala", + ".lua", + ".pl", + ".r", + ".m", + ".mm", + ".f", + ".jl", + ".cs", + ".vb", + ".pas", + ".groovy", + ".hs", + ".elm", + ".erl", + ".clj", + ".lisp", + ]: file_path = os.path.join(root, file) file_paths.append(file_path) try: - with open(file_path, 'r', errors='ignore') as f: + with open(file_path, "r", errors="ignore") as f: content = f.read() file_contents.append(content) except Exception as e: - print(f'Error reading {file_path}: {e}') + print(f"Error reading {file_path}: {e}") else: skipped_file_types[file_extension] = True return file_paths, file_contents - def build_vector_database(self, file_contents:str) -> list[any]: + def build_vector_database(self, file_contents: str) -> list[any]: embeddings = self.model.encode(file_contents, convert_to_tensor=True) return embeddings - def save_vector_database(self, file_paths:list[str], embeddings:list[any], db_path:Path) -> None: - torch.save({'file_paths': file_paths, 'embeddings': embeddings}, db_path) + def save_vector_database(self, file_paths: list[str], embeddings: list[any], db_path: Path) -> None: + torch.save({"file_paths": file_paths, "embeddings": embeddings}, db_path) - def load_vector_database(self, db_path:Path) -> tuple[list[str], list[any]]: + def load_vector_database(self, db_path: Path) -> tuple[list[str], list[any]]: if db_path is not None and os.path.exists(db_path): data = torch.load(db_path, weights_only=True) else: raise ValueError(f"Database path {db_path} does not exist.") - return data['file_paths'], data['embeddings'] + return data["file_paths"], data["embeddings"] - def find_similar_files(self, query:str, file_paths:list[Path], embeddings:tuple[list[str], list[any]]) -> list[str]: + def find_similar_files( + self, query: str, file_paths: list[Path], embeddings: tuple[list[str], list[any]] + ) -> list[str]: if embeddings.size(0) == 0: - return 'No embeddings available to query against' + return "No embeddings available to query against" query_embedding = self.model.encode([query], convert_to_tensor=True).cpu().numpy() embeddings_np = embeddings.cpu().numpy() index = faiss.IndexFlatL2(embeddings_np.shape[1]) diff --git a/tests/toolkit/test_vector.py b/tests/toolkit/test_vector.py index 87e3bed5..6e3935a7 100644 --- a/tests/toolkit/test_vector.py +++ b/tests/toolkit/test_vector.py @@ -6,15 +6,18 @@ import pytest from goose.toolkit.vector import VectorToolkit + @pytest.fixture def temp_dir(): with TemporaryDirectory() as temp_dir: yield Path(temp_dir) + @pytest.fixture def vector_toolkit(): return VectorToolkit(notifier=MagicMock()) + def test_query_vector_db_creates_db(temp_dir, vector_toolkit): # Create and load a vector database lazily query = 'print("Hello World")' @@ -36,7 +39,7 @@ def test_query_vector_db(temp_dir, vector_toolkit): temp_db_path = vector_toolkit.get_db_path(temp_dir.as_posix()) assert os.path.exists(temp_db_path) assert os.path.getsize(temp_db_path) > 0 - assert 'No embeddings available to query against' in result or '\n' in result + assert "No embeddings available to query against" in result or "\n" in result def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit): @@ -64,18 +67,16 @@ def test_no_new_db_if_exists_higher(temp_dir, vector_toolkit): def test_find_similar_files_in_repo(temp_dir, vector_toolkit): # Setting up a temporary repository structure file_structure = { - 'file1.py': 'def function_one(): pass\n', - 'file2.py': 'def function_two(): pass\n', - 'subdir': { - 'file3.py': 'class MyClass: pass\n' - } + "file1.py": "def function_one(): pass\n", + "file2.py": "def function_two(): pass\n", + "subdir": {"file3.py": "class MyClass: pass\n"}, } def create_files(base_path, structure): for name, content in structure.items(): path = base_path / name if isinstance(content, str): - with open(path, 'w') as f: + with open(path, "w") as f: f.write(content) else: path.mkdir() @@ -87,7 +88,7 @@ def create_files(base_path, structure): vector_toolkit.create_vector_db(temp_dir.as_posix()) # Test query - query = 'def function_one' + query = "def function_one" result = vector_toolkit.find_similar_files_locations(temp_dir.as_posix(), query) print("Similar Files Result:", result) - assert 'file1.py' in result or 'subdir' in result + assert "file1.py" in result or "subdir" in result From 9bcb9a995b942fcc514b92145dd876f8e7a77912 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Mon, 16 Sep 2024 21:24:22 +1000 Subject: [PATCH 22/23] making vector search stuff optional --- pyproject.toml | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d96c9b8..852030e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,13 +10,19 @@ dependencies = [ "ruamel-yaml>=0.18.6", "ai-exchange>=0.9.0", "click>=8.1.7", - "prompt-toolkit>=3.0.47", + "prompt-toolkit>=3.0.47" +] +author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] +packages = [{ include = "goose", from = "src" }] + + +[project.optional-dependencies] +vector = [ "sentence-transformers>=3.0.1", "torch>=2.4.0", "faiss-cpu" ] -author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] -packages = [{ include = "goose", from = "src" }] + [tool.hatch.build.targets.wheel] packages = ["src/goose"] From 9a488cd251e684c36fc87dc635304fb5e476269b Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Mon, 16 Sep 2024 21:35:51 +1000 Subject: [PATCH 23/23] need to include extras in ci --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0fb1701e..a93b78a5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -24,4 +24,4 @@ jobs: - name: Run tests run: | - uv run pytest tests -m 'not integration' + uv run --all-extras pytest tests -m 'not integration'