From e04950a672479ff36d033f0e914003fefd1e99c5 Mon Sep 17 00:00:00 2001 From: makrianast Date: Wed, 6 Nov 2024 01:23:58 +0200 Subject: [PATCH 1/3] Encode texts in batches in convert_texts_to_vector --- src/harmony/matching/default_matcher.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/harmony/matching/default_matcher.py b/src/harmony/matching/default_matcher.py index 1f8ada7..cf81be8 100644 --- a/src/harmony/matching/default_matcher.py +++ b/src/harmony/matching/default_matcher.py @@ -47,10 +47,19 @@ model = SentenceTransformer(sentence_transformer_path) -def convert_texts_to_vector(texts: List) -> ndarray: - embeddings = model.encode(sentences=texts, convert_to_numpy=True) +def convert_texts_to_vector(texts: List, batch_size=50) -> ndarray: + embeddings = [] + + # Process texts in batches + for i in range(0, len(texts), batch_size): + batch = texts[i:i + batch_size] + batch_embeddings = model.encode(sentences=batch, convert_to_numpy=True) + embeddings.append(batch_embeddings) + + # Concatenate all batch embeddings into a single NumPy array + return np.concatenate(embeddings, axis=0) + - return embeddings def match_instruments( @@ -59,12 +68,13 @@ def match_instruments( mhc_questions: List = [], mhc_all_metadatas: List = [], mhc_embeddings: np.ndarray = np.zeros((0, 0)), - texts_cached_vectors: dict[str, List[float]] = {}, + texts_cached_vectors: dict[str, List[float]] = {},batch_size: int = 50, + ) -> tuple: return match_instruments_with_function( instruments=instruments, query=query, - vectorisation_function=convert_texts_to_vector, + vectorisation_function=lambda texts: convert_texts_to_vector(texts, batch_size=batch_size), mhc_questions=mhc_questions, mhc_all_metadatas=mhc_all_metadatas, mhc_embeddings=mhc_embeddings, From 9f16b41f402f101fc9aeb6930261430ad2c1265b Mon Sep 17 00:00:00 2001 From: makrianast Date: Tue, 12 Nov 2024 12:35:34 +0200 Subject: [PATCH 2/3] Add max_batches limit and batch_size=0 passthrough; add tests/test_batch.py --- src/harmony/matching/default_matcher.py | 17 +++++-- tests/test_batch.py 
| 62 +++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 tests/test_batch.py diff --git a/src/harmony/matching/default_matcher.py b/src/harmony/matching/default_matcher.py index cf81be8..2a2d3a8 100644 --- a/src/harmony/matching/default_matcher.py +++ b/src/harmony/matching/default_matcher.py @@ -47,34 +47,43 @@ model = SentenceTransformer(sentence_transformer_path) -def convert_texts_to_vector(texts: List, batch_size=50) -> ndarray: +def convert_texts_to_vector(texts: List, batch_size=50, max_batches=2000) -> ndarray: + if batch_size==0: + embeddings = model.encode(sentences=texts, convert_to_numpy=True) + + return embeddings + embeddings = [] + batch_count = 0 + # Process texts in batches for i in range(0, len(texts), batch_size): + if batch_count >= max_batches: + break batch = texts[i:i + batch_size] batch_embeddings = model.encode(sentences=batch, convert_to_numpy=True) embeddings.append(batch_embeddings) + batch_count += 1 # Concatenate all batch embeddings into a single NumPy array return np.concatenate(embeddings, axis=0) - def match_instruments( instruments: List[Instrument], query: str = None, mhc_questions: List = [], mhc_all_metadatas: List = [], mhc_embeddings: np.ndarray = np.zeros((0, 0)), - texts_cached_vectors: dict[str, List[float]] = {},batch_size: int = 50, + texts_cached_vectors: dict[str, List[float]] = {},batch_size: int = 50, max_batches: int =2000, ) -> tuple: return match_instruments_with_function( instruments=instruments, query=query, - vectorisation_function=lambda texts: convert_texts_to_vector(texts, batch_size=batch_size), + vectorisation_function=lambda texts: convert_texts_to_vector(texts, batch_size=batch_size, max_batches=max_batches), mhc_questions=mhc_questions, mhc_all_metadatas=mhc_all_metadatas, mhc_embeddings=mhc_embeddings, diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..15c9902 --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,62 @@ +''' 
+MIT License + +Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). +Project: Harmony (https://harmonydata.ac.uk) +Maintainer: Thomas Wood (https://fastdatascience.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +''' + +import sys +import unittest +import numpy + +from src.harmony.matching.default_matcher import convert_texts_to_vector + + +class createModel: + def encode(self, sentences, convert_to_numpy=True): + # Generate a dummy embedding with 768 dimensions for each sentence + return numpy.array([[1] * 768] * len(sentences)) + + + +model = createModel() + +class TestBatching(unittest.TestCase): + def test_convert_texts_to_vector_with_batching(self): + # Create a list of 10 dummy texts + texts = ["text" + str(i) for i in range(10)] + + + batch_size = 5 + max_batches = 2 + embeddings = convert_texts_to_vector(texts, batch_size=batch_size, max_batches=max_batches) + + + self.assertEqual(embeddings.shape[0], 10) + + + self.assertEqual(embeddings.shape[1], 384) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From b10ba08c7cff3c4c3ea53cbe43afd3fa7ba214ef Mon Sep 17 00:00:00 2001 From: makrianast Date: Mon, 18 Nov 2024 13:25:27 +0200 Subject: [PATCH 3/3] Reorder imports and reformat default_matcher.py; change import path in tests/test_batch.py --- src/harmony/matching/default_matcher.py | 28 ++++++++++++------------- tests/test_batch.py | 5 +++-- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/harmony/matching/default_matcher.py b/src/harmony/matching/default_matcher.py index 2a2d3a8..902f931 100644 --- a/src/harmony/matching/default_matcher.py +++ b/src/harmony/matching/default_matcher.py @@ -28,15 +28,14 @@ from typing import List import numpy as np -from numpy import ndarray -from sentence_transformers import SentenceTransformer - from harmony import match_instruments_with_function from harmony.schemas.requests.text import Instrument +from numpy import ndarray +from sentence_transformers import SentenceTransformer if ( - os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) is not None - and os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) != "" + os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) is not None + and os.environ.get("HARMONY_SENTENCE_TRANSFORMER_PATH", None) != "" 
): sentence_transformer_path = os.environ["HARMONY_SENTENCE_TRANSFORMER_PATH"] else: @@ -48,7 +47,7 @@ def convert_texts_to_vector(texts: List, batch_size=50, max_batches=2000) -> ndarray: - if batch_size==0: + if batch_size == 0: embeddings = model.encode(sentences=texts, convert_to_numpy=True) return embeddings @@ -56,7 +55,6 @@ def convert_texts_to_vector(texts: List, batch_size=50, max_batches=2000) -> nda embeddings = [] batch_count = 0 - # Process texts in batches for i in range(0, len(texts), batch_size): if batch_count >= max_batches: @@ -70,20 +68,20 @@ def convert_texts_to_vector(texts: List, batch_size=50, max_batches=2000) -> nda return np.concatenate(embeddings, axis=0) - def match_instruments( - instruments: List[Instrument], - query: str = None, - mhc_questions: List = [], - mhc_all_metadatas: List = [], - mhc_embeddings: np.ndarray = np.zeros((0, 0)), - texts_cached_vectors: dict[str, List[float]] = {},batch_size: int = 50, max_batches: int =2000, + instruments: List[Instrument], + query: str = None, + mhc_questions: List = [], + mhc_all_metadatas: List = [], + mhc_embeddings: np.ndarray = np.zeros((0, 0)), + texts_cached_vectors: dict[str, List[float]] = {}, batch_size: int = 50, max_batches: int = 2000, ) -> tuple: return match_instruments_with_function( instruments=instruments, query=query, - vectorisation_function=lambda texts: convert_texts_to_vector(texts, batch_size=batch_size, max_batches=max_batches), + vectorisation_function=lambda texts: convert_texts_to_vector(texts, batch_size=batch_size, + max_batches=max_batches), mhc_questions=mhc_questions, mhc_all_metadatas=mhc_all_metadatas, mhc_embeddings=mhc_embeddings, diff --git a/tests/test_batch.py b/tests/test_batch.py index 15c9902..9795c44 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -29,8 +29,9 @@ import unittest import numpy -from src.harmony.matching.default_matcher import convert_texts_to_vector +sys.path.append("../src") +from harmony.matching.default_matcher import 
convert_texts_to_vector class createModel: def encode(self, sentences, convert_to_numpy=True): @@ -59,4 +60,4 @@ def test_convert_texts_to_vector_with_batching(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()