Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
moria97 committed Oct 21, 2024
1 parent c880325 commit 86e5461
Show file tree
Hide file tree
Showing 16 changed files with 55 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ jobs:
IS_PAI_RAG_CI_TEST: true
PAIRAG_RAG__embedding__source: "DashScope"
PAIRAG_RAG__llm__source: "DashScope"
PAIRAG_RAG__llm__name: "qwen-turbo"
PAIRAG_RAG__llm__model: "qwen-max"
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
IS_PAI_RAG_CI_TEST: true
PAIRAG_RAG__embedding__source: "DashScope"
PAIRAG_RAG__llm__source: "DashScope"
PAIRAG_RAG__llm__name: "qwen-turbo"
PAIRAG_RAG__llm__model: "qwen-max"

- name: Get Cover
uses: orgoro/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/main_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ jobs:
IS_PAI_RAG_CI_TEST: true
PAIRAG_RAG__embedding__source: "DashScope"
PAIRAG_RAG__llm__source: "DashScope"
PAIRAG_RAG__llm__name: "qwen-turbo"
PAIRAG_RAG__llm__model: "qwen-max"
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ vector_store.type = "FAISS"
# token = ""
[rag.llm]
source = "DashScope"
model = "qwen-turbo"
model = "qwen-max"

[rag.llm.function_calling_llm]
source = "DashScope"
Expand Down
7 changes: 4 additions & 3 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,15 @@ def task_status(task_id: str):
async def upload_data(
files: List[UploadFile] = Body(None),
oss_path: str = Form(None),
faiss_path: str = Form(None),
index_name: str = Form(None),
enable_raptor: bool = Form(False),
enable_multimodal: bool = Form(False),
background_tasks: BackgroundTasks = BackgroundTasks(),
):
task_id = uuid.uuid4().hex

logger.info(
f"Upload data task_id: {task_id} index_name: {index_name} enable_multimodal: {enable_multimodal}"
)
if oss_path:
background_tasks.add_task(
rag_service.add_knowledge,
Expand All @@ -161,7 +162,6 @@ async def upload_data(
oss_path=oss_path,
from_oss=True,
index_name=index_name,
faiss_path=faiss_path,
enable_raptor=enable_raptor,
enable_multimodal=enable_multimodal,
)
Expand Down Expand Up @@ -194,6 +194,7 @@ async def upload_data(
oss_path=None,
enable_raptor=enable_raptor,
temp_file_dir=tmpdir,
enable_multimodal=enable_multimodal,
)

return {"task_id": task_id}
Expand Down
27 changes: 17 additions & 10 deletions src/pai_rag/app/web/tabs/settings_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,16 @@ def create_setting_tab() -> Dict[str, Any]:
elem_id="use_mllm",
container=False,
)
with gr.Row(visible=False) as use_mllm_col:
with gr.Row(visible=False, elem_id="use_mllm_col") as use_mllm_col:
mllm = gr.Radio(
["paieas", "dashscope"],
label="LLM Model Source",
elem_id="mllm",
interactive=DEFAULT_IS_INTERACTIVE.lower() != "false",
)
with gr.Row(visible=(mllm == "paieas")) as m_eas_col:
with gr.Row(
visible=(mllm == "paieas"), elem_id="m_eas_col"
) as m_eas_col:
mllm_eas_url = gr.Textbox(
label="EAS Url",
elem_id="mllm_eas_url",
Expand All @@ -216,7 +218,9 @@ def create_setting_tab() -> Dict[str, Any]:
elem_id="mllm_eas_model_name",
interactive=True,
)
with gr.Row(visible=(mllm == "dashscope")) as api_mllm_col:
with gr.Row(
visible=(mllm == "dashscope"), elem_id="api_mllm_col"
) as api_mllm_col:
mllm_api_model_name = gr.Dropdown(
label="LLM Model Name",
elem_id="mllm_api_model_name",
Expand All @@ -231,7 +235,11 @@ def create_setting_tab() -> Dict[str, Any]:
elem_id="use_oss",
container=False,
)
with gr.Row(visible=False) as use_oss_col:
with gr.Row(visible=False, elem_id="use_oss_col") as use_oss_col:
oss_bucket = gr.Textbox(
label="OSS Bucket",
elem_id="oss_bucket",
)
oss_ak = gr.Textbox(
label="Access Key",
elem_id="oss_ak",
Expand All @@ -245,10 +253,7 @@ def create_setting_tab() -> Dict[str, Any]:
oss_endpoint = gr.Textbox(
label="OSS Endpoint",
elem_id="oss_endpoint",
)
oss_bucket = gr.Textbox(
label="OSS Bucket",
elem_id="oss_bucket",
default="oss-cn-hangzhou.aliyuncs.com",
)
use_oss.input(
fn=ev_listeners.change_use_oss,
Expand Down Expand Up @@ -312,8 +317,10 @@ def create_setting_tab() -> Dict[str, Any]:
elems.update(vector_db_components)
elems.update(
{
"use_oss_col": use_oss_col,
"use_mllm_col": use_mllm_col,
m_eas_col.elem_id: m_eas_col,
api_mllm_col.elem_id: api_mllm_col,
use_oss_col.elem_id: use_oss_col,
use_mllm_col.elem_id: use_mllm_col,
}
)
return elems
1 change: 1 addition & 0 deletions src/pai_rag/app/web/tabs/upload_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def upload_knowledge(
input_files=[file.name for file in upload_files],
enable_raptor=enable_raptor,
index_name=index_name,
enable_multimodal=enable_multimodal,
)
for file in upload_files:
base_name = os.path.basename(file.name)
Expand Down
4 changes: 3 additions & 1 deletion src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def to_app_config(self):

config["retriever"]["image_similarity_top_k"] = self.image_similarity_top_k

config["retriever"]["need_image"] = self.need_image
config["retriever"]["search_image"] = self.need_image
if self.retrieval_mode == "Hybrid":
config["retriever"]["vector_store_query_mode"] = VectorStoreQueryMode.HYBRID
elif self.retrieval_mode == "Embedding Only":
Expand Down Expand Up @@ -432,6 +432,8 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
"choices": MLLM_MODEL_KEY_DICT.get(self.mllm, []),
"visible": self.mllm.lower() != "paieas",
}
settings["m_eas_col"] = {"visible": self.mllm == "paieas"}
settings["api_mllm_col"] = {"visible": self.mllm == "dashscope"}

settings["use_oss"] = {"value": self.use_oss}
settings["use_oss_col"] = {"visible": self.use_oss}
Expand Down
3 changes: 2 additions & 1 deletion src/pai_rag/app/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pai_rag.app.web.tabs.settings_tab import create_setting_tab
from pai_rag.app.web.tabs.upload_tab import create_upload_tab
from pai_rag.app.web.tabs.chat_tab import create_chat_tab
from pai_rag.app.web.tabs.agent_tab import create_agent_tab
from pai_rag.app.web.tabs.data_analysis_tab import create_data_analysis_tab
from pai_rag.app.web.index_utils import index_related_component_keys

Expand Down Expand Up @@ -84,9 +83,11 @@ def make_homepage():
with gr.Tab("\N{fire} Chat"):
chat_elements = create_chat_tab()
elem_manager.add_elems(chat_elements)
""" hide agent tab
with gr.Tab("\N{rocket} Agent"):
agent_elements = create_agent_tab()
elem_manager.add_elems(agent_elements)
"""
with gr.Tab("\N{bar chart} Data Analysis"):
analysis_elements = create_data_analysis_tab()
elem_manager.add_elems(analysis_elements)
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ vector_store.type = "FAISS"
# token = ""
[rag.llm]
source = "DashScope"
model = "qwen-turbo"
model = "qwen-max"

[rag.multimodal_embedding]
source = "cnclip"
Expand Down
10 changes: 10 additions & 0 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

DEFAULT_EMPTY_RESPONSE_GEN = "Empty Response"
DEFAULT_RAG_INDEX_FILE = "localdata/default_rag_indexes.json"
logger = logging.getLogger(__name__)


def uuid_generator() -> str:
Expand Down Expand Up @@ -89,11 +90,20 @@ def load_knowledge(
enable_raptor=False,
enable_multimodal=False,
):
logger.info(
f"""Loading data:
input_files: {input_files}
index_name: {index_name}
enable_multimodal: {enable_multimodal}
enable_raptor: {enable_raptor}"""
)

session_config = self.config.model_copy()
index_entry = index_manager.get_index_by_name(index_name)
session_config.embedding = index_entry.embedding_config
session_config.index.vector_store = index_entry.vector_store_config
session_config.node_parser.enable_multimodal = enable_multimodal

data_loader = resolve_data_loader(session_config)
data_loader.load_data(
file_path_or_directory=input_files,
Expand Down
8 changes: 6 additions & 2 deletions src/pai_rag/core/rag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
PaiBaseEmbeddingConfig,
)
from pai_rag.integrations.index.pai.vector_store_config import PaiVectorIndexConfig
from pai_rag.integrations.llms.pai.llm_config import PaiBaseLlmConfig
from pai_rag.integrations.llms.pai.llm_config import (
DashScopeMultiModalLlmConfig,
PaiBaseLlmConfig,
PaiEasLlmConfig,
)
from pai_rag.integrations.nodeparsers.pai.pai_node_parser import NodeParserConfig
from pai_rag.integrations.postprocessor.pai.pai_postprocessor import (
RerankModelPostProcessorConfig,
Expand Down Expand Up @@ -75,7 +79,7 @@ class RagConfig(BaseModel):
BeforeValidator(validate_case_insensitive),
]
multimodal_llm: Annotated[
Union[PaiBaseLlmConfig.get_subclasses()],
Union[DashScopeMultiModalLlmConfig, PaiEasLlmConfig],
Field(discriminator="source"),
BeforeValidator(validate_case_insensitive),
] | None = None
Expand Down
2 changes: 0 additions & 2 deletions src/pai_rag/core/rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
def resolve(cls: Any, **kwargs):
    """Return a cached instance of ``cls`` built from ``kwargs``.

    The first call for a given ``(cls, kwargs)`` combination constructs
    ``cls(**kwargs)`` and stores it in the module-level ``cls_cache``; later
    calls with the same combination return that same object.

    Args:
        cls: The class (or other hashable callable) to instantiate.
        **kwargs: Keyword arguments forwarded to ``cls``. Their ``repr`` is
            part of the cache key, so arguments whose ``repr`` is equal are
            treated as equal.

    Returns:
        The cached (or newly created) object produced by ``cls(**kwargs)``.
    """
    # Include ``cls`` in the key: keying on the kwargs alone lets two
    # different classes called with identical kwargs share one cache slot,
    # handing back an instance of the wrong class.
    cls_key = (cls, repr(kwargs))
    if cls_key not in cls_cache:
        cls_cache[cls_key] = cls(**kwargs)
    return cls_cache[cls_key]

Expand Down
1 change: 1 addition & 0 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def get_config(self):
def reload(self, new_config: Dict):
    """Merge ``new_config`` into the configuration, refresh, then persist.

    The steps run in order: update the held configuration, refresh the RAG
    application with the resulting value, and finally persist the
    configuration (so persistence happens only after the refresh returns).
    """
    configuration = self.rag_configuration
    configuration.update(new_config)
    current_value = configuration.get_value()
    self.rag.refresh(current_value)
    configuration.persist()

def add_knowledge(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/integrations/index/pai/vector_store_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class PaiVectorIndexConfig(BaseModel):
vector_store: Annotated[
Union[BaseVectorStoreConfig.get_subclasses()], Field(discriminator="type")
]
enable_multimodal: bool = False
enable_multimodal: bool = True # default enable multimodal
persist_path: str = DEFAULT_LOCAL_STORAGE_PATH


Expand Down
6 changes: 5 additions & 1 deletion src/pai_rag/integrations/llms/pai/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class DashScopeLlmConfig(PaiBaseLlmConfig):
source: Literal[SupportedLlmType.dashscope] = SupportedLlmType.dashscope
api_key: str | None = None
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
model: str = "qwen-turbo"
model: str = "qwen-max"


class OpenAILlmConfig(PaiBaseLlmConfig):
Expand All @@ -160,6 +160,10 @@ class PaiEasLlmConfig(PaiBaseLlmConfig):
model: str = "default"


class DashScopeMultiModalLlmConfig(DashScopeLlmConfig):
    """DashScope LLM config that only overrides the default ``model``.

    Every other field is inherited unchanged from ``DashScopeLlmConfig``.
    """

    # NOTE(review): this class inherits the same ``source`` value as
    # DashScopeLlmConfig. If PaiBaseLlmConfig.get_subclasses() also returns
    # indirect subclasses, a mapping keyed by get_type() could end up pointing
    # the DashScope key at this class instead of the parent -- confirm.
    # Default model name; presumably a multimodal-capable model, judging by
    # the class name -- confirm.
    model: str = "qwen-vl-max"


# Lookup table from each config class's ``get_type()`` result to the class
# itself, built over everything ``PaiBaseLlmConfig.get_subclasses()`` returns.
# If two classes report the same type, the later one in iteration order wins.
# NOTE(review): the name is misspelled ("Supportted"); left as-is because
# other modules may import it under this name.
SupporttedLlmClsMap = {cls.get_type(): cls for cls in PaiBaseLlmConfig.get_subclasses()}


Expand Down

0 comments on commit 86e5461

Please sign in to comment.