Skip to content

Commit

Permalink
Fix dependency and CI pipeline (#40)
Browse files Browse the repository at this point in the history
* Update retrieval result

* Update deps

* Update lint

* Update yaml

* Add lint

* Update lint

* Update test case

* Add job name

* Add status badge

* Add coverage task

* Fix coverage

* Fix yaml

* Add workflow permission

* Add test coverage

* Load knowledge

* Update test threshold

* Update coverage threshold

* Update name
  • Loading branch information
moria97 authored May 30, 2024
1 parent fbb6bb5 commit 1e0e488
Show file tree
Hide file tree
Showing 22 changed files with 367 additions and 606 deletions.
40 changes: 33 additions & 7 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ on:
branches:
- main
- feature
- 'releases/**'
- "releases/**"

permissions:
contents: read
pull-requests: write

jobs:
build:
name: Build and Test
runs-on: ubuntu-latest

steps:
Expand All @@ -19,17 +24,38 @@ jobs:
uses: actions/setup-python@v5
with:
# Semantic version range syntax or exact version of a Python version
python-version: '3.10'
python-version: "3.10"
# Optional - x64 or x86 architecture, defaults to x64
architecture: 'x64'
- name: Install dependencies
architecture: "x64"

- name: Install Dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
pip install poetry
poetry install
env:
POETRY_VIRTUALENVS_CREATE: false

- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install pre-commit
shell: bash
run: poetry run pip install pre-commit

- name: Run Linter
shell: bash
run: poetry run make lint

- name: Run Tests
run: |
make coveragetest
env:
DASHSCOPE_API_KEY: ${{ secrets.TESTDASHSCOPEKEY }}
IS_PAI_RAG_CI_TEST: true

- name: Get Cover
uses: orgoro/[email protected]
with:
coverageFile: localdata/test_output/coverage_report.xml
token: ${{ secrets.GITHUB_TOKEN }}
thresholdAll: 0.5 # Total coverage threshold
#thresholdNew: 0.9 # New files coverage threshold
#thresholdModified: 0.9 # Modified files coverage threshold
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
test: ## Run tests via pytest.
pytest tests

coveragetest: ## Tests with coverage report
pytest --cov-report xml:localdata/test_output/coverage_report.xml --cov=pai_rag tests

watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/

Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# PAI-RAG: An easy-to-use framework for modular RAG.

[![PAI-RAG CI](https://github.com/aigc-apps/PAI-RAG/actions/workflows/main.yml/badge.svg)](https://github.com/aigc-apps/PAI-RAG/actions/workflows/main.yml)

## Get Started

### Step1: Clone Repo

```bash
git clone [email protected]:aigc-apps/PAI-RAG.git
```
Expand Down Expand Up @@ -45,6 +47,7 @@ curl -X 'POST' http://127.0.0.1:8000/service/query -H "Content-Type: application
```

- **多轮对话请求**

```bash
curl -X 'POST' http://127.0.0.1:8000/service/query -H "Content-Type: application/json" -d '{"question":"一键助眠是什么?"}'

Expand All @@ -59,11 +62,11 @@ curl -X 'POST' http://127.0.0.1:8000/service/query -H "Content-Type: application
```

- **Agent简单对话**

```bash
curl -X 'POST' http://127.0.0.1:8000/service/query/agent -H "Content-Type: application/json" -d '{"question":"最近互联网公司有发生什么大新闻吗?"}'
```


2. Retrieval Batch评估

```bash
Expand Down
35 changes: 0 additions & 35 deletions main.yaml

This file was deleted.

714 changes: 205 additions & 509 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ llama-index-vector-stores-milvus = "^0.1.10"
gradio = "3.41.0"
faiss-cpu = "^1.8.0"
hologres-vector = "^0.0.9"
arize-phoenix = "^3.21.0"
llama-index-callbacks-arize-phoenix = "^0.1.4"
dynaconf = "^3.2.5"
docx2txt = "^0.8"
click = "^8.1.7"
Expand Down Expand Up @@ -61,6 +59,7 @@ llama-index-tools-duckduckgo = "^0.1.1"
openinference-instrumentation = "^0.1.7"
llama-index-llms-huggingface = "^0.2.0"
pytest-asyncio = "^0.23.7"
pytest-cov = "^5.0.0"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
4 changes: 3 additions & 1 deletion src/pai_rag/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ class RagResponse(BaseModel):
class LlmResponse(BaseModel):
answer: str


class ContextDoc(BaseModel):
text: str
score: float
metadata: Dict


class RetrievalResponse(BaseModel):
docs: List[ContextDoc]


class KnowledgeInput(BaseModel):
class DataInput(BaseModel):
file_path: str
enable_qa_extraction: bool = False
10 changes: 6 additions & 4 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
RetrievalQuery,
RagResponse,
LlmResponse,
KnowledgeInput,
DataInput,
)

router = APIRouter()
Expand All @@ -25,20 +25,22 @@ async def aquery_llm(query: LlmQuery) -> LlmResponse:

@router.post("/query/retrieval")
async def aquery_retrieval(query: RetrievalQuery):
return await rag_service.aquery_vectordb(query)
return await rag_service.aquery_retrieval(query)


@router.post("/query/agent")
async def aquery_agent(query: LlmQuery) -> LlmResponse:
return await rag_service.aquery_agent(query)


@router.patch("/config")
async def aupdate(new_config: Any = Body(None)):
rag_service.reload(new_config)
return {"msg": "Update RAG configuration successfully."}


@router.post("/knowledge")
async def load_knowledge(input: KnowledgeInput):
@router.post("/data")
async def load_data(input: DataInput):
await rag_service.add_knowledge(
file_dir=input.file_path, enable_qa_extraction=input.enable_qa_extraction
)
Expand Down
13 changes: 9 additions & 4 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def config_url(self):
return f"{self.endpoint}service/config"

@property
def load_knowledge_url(self):
return f"{self.endpoint}service/knowledge"
def load_data_url(self):
return f"{self.endpoint}service/data"

def query(self, text: str, session_id: str = None):
q = dict(question=text)
Expand Down Expand Up @@ -76,13 +76,18 @@ def query_vector(self, text: str):
session_id = r.headers["x-session-id"]
response = dotdict(json.loads(r.text))
response.session_id = session_id
formatted_text = "\n\n".join([f"""[Doc {i+1}] [score: {doc["score"]}]\n{doc["text"]}""" for i,doc in enumerate(response["docs"])])
formatted_text = "\n\n".join(
[
f"""[Doc {i+1}] [score: {doc["score"]}]\n{doc["text"]}"""
for i, doc in enumerate(response["docs"])
]
)
response["answer"] = formatted_text
return response

def add_knowledge(self, file_dir: str, enable_qa_extraction: bool):
q = dict(file_path=file_dir, enable_qa_extraction=enable_qa_extraction)
r = requests.post(self.load_knowledge_url, json=q)
r = requests.post(self.load_data_url, json=q)
r.raise_for_status()
return

Expand Down
1 change: 0 additions & 1 deletion src/pai_rag/app/web/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
ACCURATE_CONTENT_PROMPTS,
PROMPT_MAP,
)
from os import environ

import logging
import traceback
Expand Down
8 changes: 3 additions & 5 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ViewModel(BaseModel):
fusion_mode: str = "reciprocal_rerank" # [simple, reciprocal_rerank, dist_based_score, relative_score]
query_rewrite_n: int = 1

synthesizer_type:str = None
synthesizer_type: str = None

text_qa_template: str = None

Expand Down Expand Up @@ -161,9 +161,7 @@ def sync_app_config(self, config):
self.chunk_size = config["node_parser"]["chunk_size"]
self.chunk_overlap = config["node_parser"]["chunk_overlap"]

self.reader_type = config["data_reader"].get(
"type", self.reader_type
)
self.reader_type = config["data_reader"].get("type", self.reader_type)
self.enable_qa_extraction = config["data_reader"].get(
"enable_qa_extraction", self.enable_qa_extraction
)
Expand Down Expand Up @@ -221,7 +219,7 @@ def to_app_config(self):
config["node_parser"]["type"] = self.parser_type
config["node_parser"]["chunk_size"] = int(self.chunk_size)
config["node_parser"]["chunk_overlap"] = int(self.chunk_overlap)

config["data_reader"]["enable_qa_extraction"] = self.enable_qa_extraction
config["data_reader"]["type"] = self.reader_type

Expand Down
6 changes: 3 additions & 3 deletions src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ dynaconf_merge = true
name = "pai_rag"
version = "0.1.1"

[rag.agent]
type = "react"

[rag.chat_engine]
type = "CondenseQuestionChatEngine"

Expand Down Expand Up @@ -60,8 +63,5 @@ query_rewrite_n = 1 # set to 1 to disable query generation
type = "SimpleSummarize"
text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n"

[rag.agent]
type = "react"

[rag.tool]
type = ["calculator"]
22 changes: 14 additions & 8 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,20 @@ def reload(self, config):
async def load_knowledge(self, file_dir, enable_qa_extraction=False):
await self.data_loader.load(file_dir, enable_qa_extraction)

async def aquery_vectordb(self, query: RetrievalQuery) -> RetrievalResponse:
async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
return RagResponse(answer="Empty query. Please input your question.")
return RetrievalResponse(docs=[])

session_id = correlation_id.get() or DEFAULT_SESSION_ID
self.logger.info(f"Get session ID: {session_id}.")
node_results = await self.retriever.aretrieve(query.question)

docs = [ContextDoc(text = score_node.node.get_content(), metadata=score_node.node.metadata, score=score_node.score)
docs = [
ContextDoc(
text=score_node.node.get_content(),
metadata=score_node.node.metadata,
score=score_node.score,
)
for score_node in node_results
]
return RetrievalResponse(docs=docs)
Expand Down Expand Up @@ -106,7 +111,7 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
LlmResponse
"""
if not query.question:
return RagResponse(answer="Empty query. Please input your question.")
return LlmResponse(answer="Empty query. Please input your question.")

session_id = correlation_id.get() or DEFAULT_SESSION_ID
self.logger.info(f"Get session ID: {session_id}.")
Expand All @@ -116,7 +121,7 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
response = await llm_chat_engine.achat(query.question)
self.chat_store.persist()
return LlmResponse(answer=response.response)

async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
"""Query answer from RAG App via web search asynchronously.
Expand All @@ -128,10 +133,11 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
Returns:
LlmResponse
"""
if not query.question:
return LlmResponse(answer="Empty query. Please input your question.")

session_id = correlation_id.get()
self.logger.info(
f"Get session ID: {session_id}."
)
self.logger.info(f"Get session ID: {session_id}.")
response = await self.agent.achat(query.question)
return LlmResponse(answer=response.response)

Expand Down
6 changes: 3 additions & 3 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
return await self.rag.aquery_llm(query)

@trace_correlation_id
async def aquery_vectordb(self, query: RetrievalQuery):
return await self.rag.aquery_vectordb(query)
async def aquery_retrieval(self, query: RetrievalQuery):
return await self.rag.aquery_retrieval(query)

@trace_correlation_id
async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
return await self.rag.aquery_agent(query)

@trace_correlation_id
async def batch_evaluate_retrieval_and_response(self, type):
return await self.rag.batch_evaluate_retrieval_and_response(type)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.modules.module_registry import module_registry
from llama_index.core.prompts.prompt_type import PromptType
Expand Down Expand Up @@ -41,6 +42,9 @@ def __init__(
num_questions_per_chunk=self.num_questions_per_chunk
)
self.show_progress = show_progress
self.is_test_run = os.getenv("IS_PAI_RAG_CI_TEST") == "true"
if self.is_test_run:
self.nodes = self.nodes[:1] # Only

logging.info("dataset generation initialized successfully.")

Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"ChatEngineFactoryModule",
"LlmChatEngineFactoryModule",
"AgentModule",
"ToolModule"
"ToolModule",
]

__all__ = ALL_MODULES + ["ALL_MODULES"]
Loading

0 comments on commit 1e0e488

Please sign in to comment.