Skip to content

Commit

Permalink
Fix minor bugs (#135)
Browse files Browse the repository at this point in the history
* Fix bug

* Fix index bug

* Update password field

* Add pre-commit

* Remove upload button

* Refine upload

* Fix pg connection string
  • Loading branch information
moria97 authored Jul 31, 2024
1 parent 43b1c17 commit cd4c0b8
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 24 deletions.
104 changes: 102 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ llama-index-multi-modal-llms-dashscope = "^0.1.2"
llama-index-vector-stores-alibabacloud-opensearch = "^0.1.0"
asyncpg = "^0.29.0"
pgvector = "^0.3.2"
pre-commit = "^3.8.0"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
1 change: 1 addition & 0 deletions pyproject_gpu.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ llama-index-multi-modal-llms-dashscope = "^0.1.2"
llama-index-vector-stores-alibabacloud-opensearch = "^0.1.0"
asyncpg = "^0.29.0"
pgvector = "^0.3.2"
pre-commit = "^3.8.0"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
12 changes: 6 additions & 6 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ def _format_rag_response(
return response
elif is_finished:
for i, doc in enumerate(docs):
formatted_file_name = re.sub(
"^[0-9a-z]{32}_", "", doc["metadata"]["file_name"]
)
referenced_docs += (
f'[{i+1}]: {formatted_file_name} Score:{doc["score"]} \n'
)
filename = doc["metadata"].get("file_name", None)
if filename:
formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
referenced_docs += (
f'[{i+1}]: {formatted_file_name} Score:{doc["score"]} \n'
)
image_url = doc["metadata"].get("image_url", None)
if image_url:
images += f"""<img src="{image_url}"/>"""
Expand Down
28 changes: 23 additions & 5 deletions src/pai_rag/app/web/tabs/upload_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def upload_knowledge(
gr.update(visible=True, value=pd.DataFrame(result)),
gr.update(visible=False),
]
time.sleep(2)
if not all(file.finished is True for file in my_upload_files):
time.sleep(2)

upload_result = "Upload success."
if error_msg:
Expand All @@ -75,6 +76,13 @@ def upload_knowledge(
]


def clear_files():
    """Yield Gradio updates that hide and reset the two upload status outputs.

    This is a generator that yields exactly once. The yielded list has two
    entries, in order:

    1. an update that hides the status table and replaces its value with an
       empty ``DataFrame``;
    2. an update that hides the status text box and replaces its value with
       an empty string.
    """
    hidden_status_table = gr.update(visible=False, value=pd.DataFrame())
    hidden_status_text = gr.update(visible=False, value="")
    yield [hidden_status_table, hidden_status_text]


def create_upload_tab() -> Dict[str, Any]:
with gr.Row():
with gr.Column(scale=2):
Expand Down Expand Up @@ -102,7 +110,6 @@ def create_upload_tab() -> Dict[str, Any]:
upload_file = gr.File(
label="Upload a knowledge file.", file_count="multiple"
)
upload_file_btn = gr.Button("Upload", variant="primary")
upload_file_state_df = gr.DataFrame(
label="Upload Status Info", visible=False
)
Expand All @@ -112,12 +119,11 @@ def create_upload_tab() -> Dict[str, Any]:
label="Upload a knowledge directory.",
file_count="directory",
)
upload_dir_btn = gr.Button("Upload", variant="primary")
upload_dir_state_df = gr.DataFrame(
label="Upload Status Info", visible=False
)
upload_dir_state = gr.Textbox(label="Upload Status", visible=False)
upload_file_btn.click(
upload_file.upload(
fn=upload_knowledge,
inputs=[
upload_file,
Expand All @@ -129,7 +135,13 @@ def create_upload_tab() -> Dict[str, Any]:
outputs=[upload_file_state_df, upload_file_state],
api_name="upload_knowledge",
)
upload_dir_btn.click(
upload_file.clear(
fn=clear_files,
inputs=[],
outputs=[upload_file_state_df, upload_file_state],
api_name="clear_file",
)
upload_file_dir.upload(
fn=upload_knowledge,
inputs=[
upload_file_dir,
Expand All @@ -141,6 +153,12 @@ def create_upload_tab() -> Dict[str, Any]:
outputs=[upload_dir_state_df, upload_dir_state],
api_name="upload_knowledge_dir",
)
upload_file_dir.clear(
fn=clear_files,
inputs=[],
outputs=[upload_dir_state_df, upload_dir_state],
api_name="clear_file_dir",
)
return {
chunk_size.elem_id: chunk_size,
chunk_overlap.elem_id: chunk_overlap,
Expand Down
14 changes: 7 additions & 7 deletions src/pai_rag/app/web/tabs/vector_db_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,23 +254,23 @@ def create_vector_db_panel(
fn=connect_vector_func,
inputs=inputs_opensearch,
outputs=con_state_opensearch,
api_name="connect_faiss",
api_name="connect_opensearch",
)
with gr.Column(visible=(vectordb_type == "PostgreSQL")) as postgresql_col:
postgresql_host = gr.Textbox(label="Host", elem_id="postgresql_host")
postgresql_port = gr.Textbox(label="Port", elem_id="postgresql_port")
postgresql_username = gr.Textbox(
label="UserName", elem_id="postgresql_username"
)
postgresql_password = gr.Textbox(
label="Password", type="password", elem_id="postgresql_password"
)
postgresql_database = gr.Textbox(
label="Database", elem_id="postgresql_database"
)
postgresql_table_name = gr.Textbox(
label="TableName", elem_id="postgresql_table_name"
)
postgresql_password = gr.Textbox(
label="Password", elem_id="postgresql_password"
)
postgresql_username = gr.Textbox(
label="UserName", elem_id="postgresql_username"
)
connect_btn_pg = gr.Button("Connect PostgreSQL", variant="primary")
con_state_pg = gr.Textbox(label="Connection Info: ")
inputs_pg = input_elements.union(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
from typing import Any, List, NamedTuple, Optional, Type, Union
from urllib.parse import quote_plus

import asyncpg # noqa
import pgvector # noqa
Expand Down Expand Up @@ -259,10 +260,10 @@ def from_params(
"""Return connection string from database parameters."""
conn_str = (
connection_string
or f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
or f"postgresql+psycopg2://{user}:{quote_plus(password)}@{host}:{port}/{database}"
)
async_conn_str = async_connection_string or (
f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
f"postgresql+asyncpg://{user}:{quote_plus(password)}@{host}:{port}/{database}"
)
return cls(
connection_string=conn_str,
Expand Down
11 changes: 9 additions & 2 deletions src/pai_rag/modules/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ def __init__(self, config, embed_model, postprocessor):
self.config = config
self.embed_model = embed_model
self.embed_dims = self._get_embed_vec_dim(embed_model)
self.postprocessor = postprocessor
persist_path = config.get("persist_path", DEFAULT_PERSIST_DIR)
folder_name = get_store_persist_directory_name(config, self.embed_dims)
self.persist_path = os.path.join(persist_path, folder_name)
index_entry.register(self.persist_path)

is_empty = not os.path.exists(self.persist_path)
rag_store = RagStore(
config, postprocessor, self.persist_path, is_empty, self.embed_dims
config, self.postprocessor, self.persist_path, is_empty, self.embed_dims
)
self.storage_context = rag_store.get_storage_context()

Expand Down Expand Up @@ -65,7 +66,13 @@ def load_indices(self, storage_context, embed_model):

def reload(self):
if isinstance(self.storage_context.vector_store, FaissVectorStore):
rag_store = RagStore(self.config, self.persist_path, False, self.embed_dims)
rag_store = RagStore(
self.config,
self.postprocessor,
self.persist_path,
False,
self.embed_dims,
)
self.storage_context = rag_store.get_storage_context()

self.vector_index = load_index_from_storage(
Expand Down

0 comments on commit cd4c0b8

Please sign in to comment.