Skip to content

Commit

Permalink
Add generate_questions in insert pipeline
Browse files (browse the repository at this point in the history)
Signed-off-by: Jael Gu <[email protected]>
  • Loading branch information
jaelgu committed Aug 10, 2023
1 parent f6e9c50 commit 5318685
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
'password': os.getenv('MILVUS_PASSWORD', ''),
'secure': True if os.getenv('MILVUS_SECURE', 'False').lower() == 'true' else False
},
'top_k': 1,
'top_k': 5,
'threshold': 0,
'index_params': {
'metric_type': 'IP',
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ unstructured
pexpect
pdf2image
SQLAlchemy>=2.0.15
# psycopg2-binary
psycopg2-binary
openai
gradio>=3.30.0
fastapi
uvicorn
towhee>=1.1.0
pymilvus
elasticsearch
2 changes: 1 addition & 1 deletion src_towhee/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def clear_history(project, session_id):


# if __name__ == '__main__':
# project = 'akcio'
# project = 'akcio_test'
# data_src = 'https://towhee.io'
# session_id = 'test000'
# question0 = 'What is your code name?'
Expand Down
17 changes: 13 additions & 4 deletions src_towhee/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
from typing import Any, Dict

from pymilvus import Collection, connections
from towhee import AutoConfig, AutoPipes
from towhee import AutoConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from config import ( # pylint: disable=C0413
USE_SCALAR, LLM_OPTION,
TEXTENCODER_CONFIG, CHAT_CONFIG,
VECTORDB_CONFIG, SCALARDB_CONFIG,
RERANK_CONFIG, QUERY_MODE, INSERT_MODE
RERANK_CONFIG, QUERY_MODE, INSERT_MODE,
DATAPARSER_CONFIG
)
from src_towhee.base import BasePipelines # pylint: disable=C0413
from src_towhee.pipelines.search import build_search_pipeline # pylint: disable=C0413
Expand All @@ -29,7 +30,8 @@ def __init__(self,
scalardb_config: Dict = SCALARDB_CONFIG,
rerank_config: Dict = RERANK_CONFIG,
query_mode: str = QUERY_MODE,
insert_mode: str = INSERT_MODE
insert_mode: str = INSERT_MODE,
chunk_size: int = DATAPARSER_CONFIG['chunk_size']
): # pylint: disable=W0102
self.use_scalar = use_scalar
self.llm_src = llm_src
Expand All @@ -39,6 +41,7 @@ def __init__(self,
self.chat_config = chat_config
self.textencoder_config = textencoder_config
self.rerank_config = rerank_config
self.chunk_size = chunk_size

self.milvus_uri = vectordb_config['connection_args']['uri']
self.milvus_host = self.milvus_uri.split('https://')[1].split(':')[0]
Expand Down Expand Up @@ -125,7 +128,13 @@ def search_config(self):

@property
def insert_config(self):
insert_config = AutoConfig.load_config('osschat-insert')
insert_config = AutoConfig.load_config(
'osschat-insert',
llm_src=self.llm_src,
**self.chat_config[self.llm_src]
)
# Configure chunk size
insert_config.chunk_size = self.chunk_size

# Configure embedding
insert_config.embedding_model = self.textencoder_config['model']
Expand Down
6 changes: 3 additions & 3 deletions src_towhee/pipelines/insert/generate_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def build_prompt(doc, project: str = '', num: int = 10):
query_prompt = QUERY_TEMP.format(project=project, doc=doc)
return [({'system': sys_prompt}), ({'question': query_prompt})]

def parse_output(res):
def parse_output(doc, res):
questions = []
for q in res.split('\n'):
q = ('. ').join(q.split('. ')[1:])
questions.append(q)
return questions
return [(doc, q) for q in questions]

def custom_pipeline(config):
embedding_op = get_embedding_op(config)
Expand All @@ -37,7 +37,7 @@ def custom_pipeline(config):
))
.map(('chunk', 'project'), 'prompt', build_prompt)
.map('prompt', 'gen_res', llm_op)
.flat_map('gen_res', 'gen_question', parse_output)
.flat_map(('chunk', 'gen_res'), ('chunk', 'gen_question'), parse_output)
.map('gen_question', 'embedding', embedding_op)
)

Expand Down
1 change: 1 addition & 0 deletions src_towhee/pipelines/search/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
'''

QUERY_PROMPT = '''Use previous conversation history (if there is any) and the following pieces of context to answer the question at the end.
Don't mention that you got this answer from context.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Expand Down
2 changes: 0 additions & 2 deletions src_towhee/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@


def get_llm_op(config):
if config.customize_llm:
return config.customize_llm
if config.llm_src.lower() == 'openai':
return ops.LLM.OpenAI(model_name=config.openai_model, api_key=config.openai_api_key, **config.llm_kwargs)
if config.llm_src.lower() == 'dolly':
Expand Down

0 comments on commit 5318685

Please sign in to comment.