Skip to content

Commit

Permalink
Update retrieval result (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
moria97 authored May 29, 2024
1 parent 387cc69 commit aff0197
Show file tree
Hide file tree
Showing 12 changed files with 16 additions and 107 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
git clone [email protected]:aigc-apps/PAI-RAG.git
```

注:如果需要调用Open AI,需要使用新加坡开发机器,不能连通弹内Gitlab环境,需要手动将代码上传到新加坡机器

### Step2: 配置环境

本项目使用poetry进行管理,建议在安装环境之前先创建一个空环境。为了确保环境一致性并避免因Python版本差异造成的问题,我们指定Python版本为3.10。
Expand All @@ -36,7 +34,7 @@ poetry install
pai_rag run [--host HOST] [--port PORT] [--config CONFIG_FILE]
```

现在你可以使用命令行向服务侧发送API请求,或者直接打开http://HOST:PORT。
现在你可以使用命令行向服务侧发送API请求,或者直接打开http://localhost:8000

1.

Expand Down
6 changes: 5 additions & 1 deletion src/pai_rag/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ class RagResponse(BaseModel):
class LlmResponse(BaseModel):
answer: str

class ContextDoc(BaseModel):
text: str
score: float
metadata: Dict

class RetrievalResponse(BaseModel):
answer: str
docs: List[ContextDoc]


class KnowledgeInput(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def query_vector(self, text: str):
session_id = r.headers["x-session-id"]
response = dotdict(json.loads(r.text))
response.session_id = session_id

formatted_text = "\n\n".join([f"""[Doc {i+1}] [score: {doc["score"]}]\n{doc["text"]}""" for i,doc in enumerate(response["docs"])])
response["answer"] = formatted_text
return response

def add_knowledge(self, file_dir: str, enable_qa_extraction: bool):
Expand Down
8 changes: 2 additions & 6 deletions src/pai_rag/app/web/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#### \N{fire} Platform: [PAI](https://help.aliyun.com/zh/pai) / [PAI-EAS](https://www.aliyun.com/product/bigdata/learn/eas) / [PAI-DSW](https://pai.console.aliyun.com/notebook)   \N{rocket} Supported VectorStores: [Hologres](https://www.aliyun.com/product/bigdata/hologram) / [ElasticSearch](https://www.aliyun.com/product/bigdata/elasticsearch) / [AnalyticDB](https://www.aliyun.com/product/apsaradb/gpdb) / [FAISS](https://python.langchain.com/docs/integrations/vectorstores/faiss)
#### \N{fire} <a href='/docs'>API Docs</a> &emsp; \N{rocket} <a href={EAS_TRACE_ENDPOINT}>View Tracing</a> &emsp; \N{fire} 欢迎加入【PAI】RAG答疑群 27370042974
#### \N{fire} <a href='/docs'>API Docs</a> &emsp; \N{rocket} \N{fire} 欢迎加入【PAI】RAG答疑群 27370042974
"""

css_style = """
Expand Down Expand Up @@ -145,11 +145,7 @@ def respond(input_elements: List[Any]):

def create_ui():
with gr.Blocks(css=css_style) as homepage:
gr.Markdown(
value=welcome_message_markdown.format(
EAS_TRACE_ENDPOINT=environ.get("EAS_ARIZE_PHOENIX_URL", "")
)
)
gr.Markdown(value=welcome_message_markdown)

with gr.Tab("\N{rocket} Settings"):
with gr.Row():
Expand Down
11 changes: 4 additions & 7 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
RetrievalQuery,
RagResponse,
LlmResponse,
ContextDoc,
RetrievalResponse,
)

Expand Down Expand Up @@ -64,15 +65,11 @@ async def aquery_vectordb(self, query: RetrievalQuery) -> RetrievalResponse:
session_id = correlation_id.get() or DEFAULT_SESSION_ID
self.logger.info(f"Get session ID: {session_id}.")
node_results = await self.retriever.aretrieve(query.question)
text_results = [
score_node.node.get_content().replace("\n", " ")

docs = [ContextDoc(text = score_node.node.get_content(), metadata=score_node.node.metadata, score=score_node.score)
for score_node in node_results
]
formatted_results = [
f"**Doc [{i}]**: {text} \n" for i, text in enumerate(text_results)
]
result = "\n".join(formatted_results)
return RetrievalResponse(answer=result)
return RetrievalResponse(docs=docs)

async def aquery(self, query: RagQuery) -> RagResponse:
"""Query answer from RAG App asynchronously.
Expand Down
1 change: 0 additions & 1 deletion src/pai_rag/integrations/llms/paieas/README.md

This file was deleted.

3 changes: 0 additions & 3 deletions src/pai_rag/integrations/llms/paieas/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(

@classmethod
def class_name(cls) -> str:
return "PaiEasLLM"
return "PaiEAS"

@property
def metadata(self) -> LLMMetadata:
Expand Down
7 changes: 0 additions & 7 deletions src/pai_rag/integrations/llms/paieas/poetry.lock

This file was deleted.

21 changes: 0 additions & 21 deletions src/pai_rag/integrations/llms/paieas/pyproject.toml

This file was deleted.

55 changes: 0 additions & 55 deletions src/pai_rag/integrations/llms/paieas/tests/test_pai_eas_llm.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/pai_rag/modules/llm/llm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from llama_index.llms.openai import OpenAI
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.dashscope import DashScope
from pai_rag.integrations.llms.paieas import PaiEAS
from pai_rag.integrations.llms.paieas.base import PaiEAS
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG

Expand Down

0 comments on commit aff0197

Please sign in to comment.