From 5318685d6402fc18278cfa21a274e7f88723bbd7 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Thu, 10 Aug 2023 15:14:09 +0800
Subject: [PATCH] Add generate_questions in insert pipeline

Signed-off-by: Jael Gu
---
 config.py                                   |  2 +-
 requirements.txt                            |  3 ++-
 src_towhee/operations.py                    |  2 +-
 src_towhee/pipelines/__init__.py            | 17 +++++++++++++----
 .../pipelines/insert/generate_questions.py  |  6 +++---
 src_towhee/pipelines/search/prompts.py      |  1 +
 src_towhee/pipelines/utils.py               |  2 --
 7 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/config.py b/config.py
index c01a496..bbf119f 100644
--- a/config.py
+++ b/config.py
@@ -73,7 +73,7 @@
         'password': os.getenv('MILVUS_PASSWORD', ''),
         'secure': True if os.getenv('MILVUS_SECURE', 'False').lower() == 'true' else False
     },
-    'top_k': 1,
+    'top_k': 5,
     'threshold': 0,
     'index_params': {
         'metric_type': 'IP',
diff --git a/requirements.txt b/requirements.txt
index e825d3d..d140115 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,10 +3,11 @@ unstructured
 pexpect
 pdf2image
 SQLAlchemy>=2.0.15
-# psycopg2-binary
+psycopg2-binary
 openai
 gradio>=3.30.0
 fastapi
 uvicorn
 towhee>=1.1.0
 pymilvus
+elasticsearch
diff --git a/src_towhee/operations.py b/src_towhee/operations.py
index fac342f..88b2495 100644
--- a/src_towhee/operations.py
+++ b/src_towhee/operations.py
@@ -109,7 +109,7 @@ def clear_history(project, session_id):
 
 
 # if __name__ == '__main__':
-#     project = 'akcio'
+#     project = 'akcio_test'
 #     data_src = 'https://towhee.io'
 #     session_id = 'test000'
 #     question0 = 'What is your code name?'
diff --git a/src_towhee/pipelines/__init__.py b/src_towhee/pipelines/__init__.py
index 1ec1d08..af5e431 100644
--- a/src_towhee/pipelines/__init__.py
+++ b/src_towhee/pipelines/__init__.py
@@ -3,7 +3,7 @@
 from typing import Any, Dict
 
 from pymilvus import Collection, connections
-from towhee import AutoConfig, AutoPipes
+from towhee import AutoConfig
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 
@@ -11,7 +11,8 @@
     USE_SCALAR, LLM_OPTION,
     TEXTENCODER_CONFIG, CHAT_CONFIG,
     VECTORDB_CONFIG, SCALARDB_CONFIG,
-    RERANK_CONFIG, QUERY_MODE, INSERT_MODE
+    RERANK_CONFIG, QUERY_MODE, INSERT_MODE,
+    DATAPARSER_CONFIG
 )
 from src_towhee.base import BasePipelines  # pylint: disable=C0413
 from src_towhee.pipelines.search import build_search_pipeline  # pylint: disable=C0413
@@ -29,7 +30,8 @@ def __init__(self,
                  scalardb_config: Dict = SCALARDB_CONFIG,
                  rerank_config: Dict = RERANK_CONFIG,
                  query_mode: str = QUERY_MODE,
-                 insert_mode: str = INSERT_MODE
+                 insert_mode: str = INSERT_MODE,
+                 chunk_size: int = DATAPARSER_CONFIG['chunk_size']
                  ):  # pylint: disable=W0102
         self.use_scalar = use_scalar
         self.llm_src = llm_src
@@ -39,6 +41,7 @@ def __init__(self,
         self.chat_config = chat_config
         self.textencoder_config = textencoder_config
         self.rerank_config = rerank_config
+        self.chunk_size = chunk_size
 
         self.milvus_uri = vectordb_config['connection_args']['uri']
         self.milvus_host = self.milvus_uri.split('https://')[1].split(':')[0]
@@ -125,7 +128,13 @@ def search_config(self):
 
     @property
     def insert_config(self):
-        insert_config = AutoConfig.load_config('osschat-insert')
+        insert_config = AutoConfig.load_config(
+            'osschat-insert',
+            llm_src=self.llm_src,
+            **self.chat_config[self.llm_src]
+        )
+        # Configure chunk size
+        insert_config.chunk_size = self.chunk_size
 
         # Configure embedding
         insert_config.embedding_model = self.textencoder_config['model']
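Note: with this change the insert pipeline needs the same LLM settings as the chat pipeline, because question generation now calls the model at ingest time. A minimal sketch of the resulting config flow, assuming the OpenAI source; the openai_model / openai_api_key / llm_kwargs keys follow this repo's CHAT_CONFIG (they are the attributes get_llm_op reads below), while the concrete model, key, chunk-size, and embedding values are illustrative placeholders, not the project's defaults:

    from towhee import AutoConfig

    # Mirrors the patched insert_config property: the insert pipeline now
    # receives the LLM source and its chat settings, plus a chunk size.
    chat_config = {
        'openai': {
            'openai_model': 'gpt-3.5-turbo',     # placeholder model name
            'openai_api_key': 'sk-placeholder',  # placeholder key
            'llm_kwargs': {},
        }
    }
    insert_config = AutoConfig.load_config(
        'osschat-insert',
        llm_src='openai',
        **chat_config['openai']
    )
    insert_config.chunk_size = 300                      # stands in for DATAPARSER_CONFIG['chunk_size']
    insert_config.embedding_model = 'some-text-encoder' # stands in for TEXTENCODER_CONFIG['model']
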
diff --git a/src_towhee/pipelines/insert/generate_questions.py b/src_towhee/pipelines/insert/generate_questions.py
index e0787f5..5305fd4 100644
--- a/src_towhee/pipelines/insert/generate_questions.py
+++ b/src_towhee/pipelines/insert/generate_questions.py
@@ -19,12 +19,12 @@ def build_prompt(doc, project: str = '', num: int = 10):
     query_prompt = QUERY_TEMP.format(project=project, doc=doc)
     return [({'system': sys_prompt}), ({'question': query_prompt})]
 
-def parse_output(res):
+def parse_output(doc, res):
     questions = []
     for q in res.split('\n'):
         q = ('. ').join(q.split('. ')[1:])
         questions.append(q)
-    return questions
+    return [(doc, q) for q in questions]
 
 def custom_pipeline(config):
     embedding_op = get_embedding_op(config)
@@ -37,7 +37,7 @@ def custom_pipeline(config):
         ))
         .map(('chunk', 'project'), 'prompt', build_prompt)
         .map('prompt', 'gen_res', llm_op)
-        .flat_map('gen_res', 'gen_question', parse_output)
+        .flat_map(('chunk', 'gen_res'), ('chunk', 'gen_question'), parse_output)
         .map('gen_question', 'embedding', embedding_op)
     )
 
diff --git a/src_towhee/pipelines/search/prompts.py b/src_towhee/pipelines/search/prompts.py
index 4718c0b..2370551 100644
--- a/src_towhee/pipelines/search/prompts.py
+++ b/src_towhee/pipelines/search/prompts.py
@@ -14,6 +14,7 @@
 '''
 
 QUERY_PROMPT = '''Use previous conversation history (if there is any) and the following pieces of context to answer the question at the end.
+Don't mention that you got this answer from context.
 If you don't know the answer, just say that you don't know, don't try to make up an answer.
 
 {context}
diff --git a/src_towhee/pipelines/utils.py b/src_towhee/pipelines/utils.py
index 0d1cdf0..8399d68 100644
--- a/src_towhee/pipelines/utils.py
+++ b/src_towhee/pipelines/utils.py
@@ -2,8 +2,6 @@
 
 
 def get_llm_op(config):
-    if config.customize_llm:
-        return config.customize_llm
     if config.llm_src.lower() == 'openai':
         return ops.LLM.OpenAI(model_name=config.openai_model, api_key=config.openai_api_key, **config.llm_kwargs)
     if config.llm_src.lower() == 'dolly':
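
Note: the reworked parse_output is what makes the new flat_map signature work. The LLM returns one numbered question per line; parse_output strips the numbering and pairs every question with the chunk it was generated from, so the pipeline can emit ('chunk', 'gen_question') tuples and store each question alongside its source text. A standalone sketch of that contract; the function body is copied from this patch, and the sample LLM response is made up for illustration:

    def parse_output(doc, res):
        # Drop the leading 'N. ' numbering from each generated line,
        # then pair every question with its source chunk.
        questions = []
        for q in res.split('\n'):
            q = ('. ').join(q.split('. ')[1:])
            questions.append(q)
        return [(doc, q) for q in questions]

    chunk = 'Towhee is a framework for building neural data processing pipelines.'
    fake_res = '1. What is Towhee?\n2. What kind of pipelines does Towhee build?'
    for doc, question in parse_output(chunk, fake_res):
        assert doc == chunk  # every question keeps a reference to its chunk
        print(question)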