From ec886afb4d6e5fa70929dd2ef86a80520325ddb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Tue, 4 Jun 2024 11:02:25 +0800 Subject: [PATCH] fix webui --- src/pai_rag/app/web/ui.py | 44 +++++++++++++------------------- src/pai_rag/config/settings.toml | 1 - 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/src/pai_rag/app/web/ui.py b/src/pai_rag/app/web/ui.py index 0a4aede4..b0e96d89 100644 --- a/src/pai_rag/app/web/ui.py +++ b/src/pai_rag/app/web/ui.py @@ -35,6 +35,7 @@ """ DEFAULT_EMBED_SIZE = 1536 +DEFAULT_HF_EMBED_MODEL = "bge-small-zh-v1.5" embedding_dim_dict = { "bge-small-zh-v1.5": 1024, @@ -163,14 +164,6 @@ def create_ui(): elem_id="embed_model", visible=(view_model.embed_source == "HuggingFace"), ) - embed_api_key = gr.Textbox( - visible=view_model.embed_source != "HuggingFace", - label="Embedding API Key", - value=view_model.embed_api_key, - type="password", - interactive=True, - elem_id="embed_api_key", - ) embed_dim = gr.Textbox( label="Embedding Dimension", value=embedding_dim_dict.get( @@ -181,16 +174,24 @@ def create_ui(): def change_emb_source(source): view_model.embed_source = source + view_model.embed_model = ( + DEFAULT_HF_EMBED_MODEL + if source == "HuggingFace" + else source + ) + _embed_dim = ( + embedding_dim_dict.get( + view_model.embed_model, DEFAULT_EMBED_SIZE + ) + if source == "HuggingFace" + else DEFAULT_EMBED_SIZE + ) return { embed_model: gr.update( - visible=(source == "HuggingFace") - ), - embed_dim: embedding_dim_dict.get( - view_model.embed_model, DEFAULT_EMBED_SIZE - ), - embed_api_key: gr.update( - visible=(source != "HuggingFace") + visible=(source == "HuggingFace"), + value=view_model.embed_model, ), + embed_dim: _embed_dim, } def change_emb_model(model): @@ -199,20 +200,17 @@ def change_emb_model(model): embed_dim: embedding_dim_dict.get( view_model.embed_model, DEFAULT_EMBED_SIZE ), - embed_api_key: gr.update( - visible=(view_model.embed_source != "HuggingFace") - ), } embed_source.change( fn=change_emb_source, inputs=embed_source, - outputs=[embed_model, embed_dim, embed_api_key], + outputs=[embed_model, embed_dim], ) embed_model.change( fn=change_emb_model, inputs=embed_model, - outputs=[embed_dim, embed_api_key], + outputs=[embed_dim], ) with gr.Column(): @@ -242,11 +240,6 @@ def change_emb_model(model): with gr.Column( visible=(view_model.llm != "PaiEas") ) as api_llm_col: - llm_api_key = gr.Textbox( - label="API Key", - value=view_model.llm_api_key, - elem_id="llm_api_key", - ) llm_api_model_name = gr.Dropdown( llm_model_key_dict.get(view_model.llm, []), label="LLM Model Name", @@ -297,7 +290,6 @@ def change_llm_src(value): embed_source, embed_model, embed_dim, - llm_api_key, llm_api_model_name, }, connect_vector_func=connect_vector_db, diff --git a/src/pai_rag/config/settings.toml b/src/pai_rag/config/settings.toml index 65d279a3..d610039e 100644 --- a/src/pai_rag/config/settings.toml +++ b/src/pai_rag/config/settings.toml @@ -21,7 +21,6 @@ type = "SimpleDirectoryReader" [rag.embedding] source = "DashScope" -model_name = "qwen-turbo" [rag.evaluation] retrieval = ["mrr", "hit_rate"]