support llm infer param: temperature #52

Merged · 3 commits · Jun 6, 2024

Changes from all commits
216 changes: 24 additions & 192 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ llama-index-llms-huggingface = "^0.2.0"
 pytest-asyncio = "^0.23.7"
 pytest-cov = "^5.0.0"
 xlrd = "^2.0.1"
+markdown = "^3.6"

 [tool.poetry.scripts]
 pai_rag = "pai_rag.main:main"
8 changes: 2 additions & 6 deletions src/pai_rag/app/api/models.py
@@ -4,19 +4,15 @@

 class RagQuery(BaseModel):
     question: str
-    topk: int | None = 3
-    topp: float | None = 0.8
-    temperature: float | None = 0.7
+    temperature: float | None = 0.1
     vector_topk: int | None = 3
     score_threshold: float | None = 0.5
     chat_history: List[Dict[str, str]] | None = None


 class LlmQuery(BaseModel):
     question: str
-    topk: int | None = 3
-    topp: float | None = 0.8
-    temperature: float | None = 0.7
+    temperature: float | None = 0.1
     chat_history: List[Dict[str, str]] | None = None

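After this change, a client request against the updated LlmQuery schema carries only `temperature` as a sampling parameter. A minimal sketch of such a call (the URL is a placeholder for illustration; the real client resolves it from configuration as `self.llm_url`):

```python
import requests

# Placeholder URL; the actual client reads this from its config (self.llm_url).
LLM_URL = "http://127.0.0.1:8000/service/query/llm"

# topk/topp were removed from the schema; temperature (server default 0.1)
# is the only remaining sampling knob.
payload = {
    "question": "What does PAI-RAG do?",
    "temperature": 0.1,
}

r = requests.post(LLM_URL, headers={"X-Session-ID": "demo-session"}, json=payload)
r.raise_for_status()
print(r.json())
```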
22 changes: 12 additions & 10 deletions src/pai_rag/app/web/rag_client.py
@@ -2,6 +2,8 @@

 from typing import Any
 import requests
+import html
+import markdown

 cache_config = None

@@ -55,11 +57,9 @@ def query_llm(
         self,
         text: str,
         session_id: str = None,
-        temperature: float = 0.7,
-        top_p: float = 0.8,
-        eas_llm_top_k: float = 30,
+        temperature: float = 0.1,
     ):
-        q = dict(question=text, topp=top_p, topk=eas_llm_top_k, temperature=temperature)
+        q = dict(question=text, temperature=temperature)

         r = requests.post(self.llm_url, headers={"X-Session-ID": session_id}, json=q)
         r.raise_for_status()
@@ -76,12 +76,14 @@ def query_vector(self, text: str):
         session_id = r.headers["x-session-id"]
         response = dotdict(json.loads(r.text))
         response.session_id = session_id
-        formatted_text = "\n\n".join(
-            [
-                f"""[Doc {i+1}] [score: {doc["score"]}]\n{doc["text"]}"""
-                for i, doc in enumerate(response["docs"])
-            ]
-        )
+        formatted_text = "<tr><th>Document</th><th>Score</th><th>Text</th></tr>\n"
+        for i, doc in enumerate(response["docs"]):
+            html_content = markdown.markdown(doc["text"])
+            safe_html_content = html.escape(html_content).replace("\n", "<br>")
+            formatted_text += '<tr style="font-size: 13px;"><td>Doc {}</td><td>{}</td><td>{}</td></tr>\n'.format(
+                i + 1, doc["score"], safe_html_content
+            )
+        formatted_text = "<table>\n<tbody>\n" + formatted_text + "</tbody>\n</table>"
         response["answer"] = formatted_text
         return response

Review thread on the formatting change:

Collaborator: Would this cause problems if the text itself is HTML, or contains escape characters and the like?

Author: Updated: added the safe_html_content conversion, which ensures HTML- and markdown-formatted content displays without garbling.
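To see what the new formatting path does, here is a standalone sketch (not code from this PR): the document text is first rendered from markdown to HTML, then the result is escaped, so any raw tags, whether generated by markdown or present in the source text, are displayed literally inside the table cell rather than interpreted by the browser.

```python
import html
import markdown

doc_text = "**bold** and <i>raw italic</i>"

# Step 1: render markdown to HTML.
html_content = markdown.markdown(doc_text)
# -> '<p><strong>bold</strong> and <i>raw italic</i></p>'

# Step 2: escape the rendered HTML and keep line breaks visible.
safe_html_content = html.escape(html_content).replace("\n", "<br>")
# -> '&lt;p&gt;&lt;strong&gt;bold&lt;/strong&gt; and &lt;i&gt;raw italic&lt;/i&gt;&lt;/p&gt;'

print(safe_html_content)
```

Note that because escaping happens after rendering, the markdown-generated tags are themselves shown as text; the point of the change is safe, ungarbled display inside the results table, not rich rendering.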
44 changes: 21 additions & 23 deletions src/pai_rag/app/web/tabs/chat_tab.py
@@ -68,7 +68,9 @@ def create_chat_tab() -> Dict[str, Any]:
             )

         with gr.Column(visible=True) as vs_col:
-            vec_model_argument = gr.Accordion("Parameters of Vector Retrieval")
+            vec_model_argument = gr.Accordion(
+                "Parameters of Vector Retrieval", open=False
+            )

             with vec_model_argument:
                 similarity_top_k = gr.Slider(
@@ -101,38 +103,22 @@ def create_chat_tab() -> Dict[str, Any]:
                 retrieval_mode,
             }
         with gr.Column(visible=True) as llm_col:
-            model_argument = gr.Accordion("Inference Parameters of LLM")
+            model_argument = gr.Accordion("Inference Parameters of LLM", open=False)
             with model_argument:
                 include_history = gr.Checkbox(
                     label="Chat history",
                     info="Query with chat history.",
                     elem_id="include_history",
                 )
-                llm_topk = gr.Slider(
-                    minimum=0,
-                    maximum=100,
-                    step=1,
-                    value=30,
-                    elem_id="llm_topk",
-                    label="Top K (choose between 0 and 100)",
-                )
-                llm_topp = gr.Slider(
-                    minimum=0,
-                    maximum=1,
-                    step=0.01,
-                    value=0.8,
-                    elem_id="llm_topp",
-                    label="Top P (choose between 0 and 1)",
-                )
                 llm_temp = gr.Slider(
                     minimum=0,
                     maximum=1,
-                    step=0.01,
-                    value=0.7,
-                    elem_id="llm_temp",
+                    step=0.001,
+                    value=0.1,
+                    elem_id="llm_temperature",
                     label="Temperature (choose between 0 and 1)",
                 )
-        llm_args = {llm_topk, llm_topp, llm_temp, include_history}
+        llm_args = {llm_temp, include_history}

         with gr.Column(visible=True) as lc_col:
             prm_type = gr.Radio(
@@ -198,26 +184,32 @@ def change_query_radio(query_type):
             if query_type == "Retrieval":
                 return {
                     vs_col: gr.update(visible=True),
+                    vec_model_argument: gr.update(open=True),
                     llm_col: gr.update(visible=False),
+                    model_argument: gr.update(open=False),
                     lc_col: gr.update(visible=False),
                 }
             elif query_type == "LLM":
                 return {
                     vs_col: gr.update(visible=False),
+                    vec_model_argument: gr.update(open=False),
                     llm_col: gr.update(visible=True),
+                    model_argument: gr.update(open=True),
                     lc_col: gr.update(visible=False),
                 }
             elif query_type == "RAG (Retrieval + LLM)":
                 return {
                     vs_col: gr.update(visible=True),
+                    vec_model_argument: gr.update(open=False),
                     llm_col: gr.update(visible=True),
+                    model_argument: gr.update(open=False),
                     lc_col: gr.update(visible=True),
                 }

         query_type.change(
             fn=change_query_radio,
             inputs=query_type,
-            outputs=[vs_col, llm_col, lc_col],
+            outputs=[vs_col, vec_model_argument, llm_col, model_argument, lc_col],
         )

         with gr.Column(scale=8):
@@ -239,6 +231,12 @@ def change_query_radio(query_type):
                 [question, chatbot, cur_tokens],
                 api_name="respond",
             )
+            question.submit(
+                respond,
+                chat_args,
+                [question, chatbot, cur_tokens],
+                api_name="respond",
+            )
             clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens])
     return {
         similarity_top_k.elem_id: similarity_top_k,
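The accordion wiring above follows a standard Gradio pattern: a Radio callback returns gr.update objects for both column visibility and accordion open state. A self-contained sketch under that assumption (simplified components, assumed Gradio 4.x-style API, not the repo's actual layout):

```python
import gradio as gr

with gr.Blocks() as demo:
    mode = gr.Radio(["Retrieval", "LLM"], value="LLM", label="Query mode")
    with gr.Column(visible=True) as llm_col:
        # Collapsed by default; opened only when the LLM mode is active.
        llm_acc = gr.Accordion("Inference Parameters of LLM", open=False)
        with llm_acc:
            temp = gr.Slider(0, 1, value=0.1, step=0.001, label="Temperature")

    def on_mode_change(m):
        # Returning a dict keyed by components updates each one in place.
        return {
            llm_col: gr.update(visible=(m == "LLM")),
            llm_acc: gr.update(open=(m == "LLM")),
        }

    # Components used as dict keys must also appear in the outputs list.
    mode.change(fn=on_mode_change, inputs=mode, outputs=[llm_col, llm_acc])

demo.launch()
```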
3 changes: 3 additions & 0 deletions src/pai_rag/app/web/view_model.py
@@ -35,6 +35,7 @@ class ViewModel(BaseModel):
     llm_eas_model_name: str = None
     llm_api_key: str = None
     llm_api_model_name: str = None
+    llm_temperature: float = 0.1

     # chunking
     parser_type: str = "Sentence"
@@ -115,6 +116,7 @@ def sync_app_config(self, config):
         self.llm_eas_url = config["llm"].get("endpoint", self.llm_eas_url)
         self.llm_eas_token = config["llm"].get("token", self.llm_eas_token)
         self.llm_api_key = config["llm"].get("api_key", self.llm_api_key)
+        self.llm_temperature = config["llm"].get("temperature", self.llm_temperature)
         if self.llm == "PaiEAS":
             self.llm_eas_model_name = config["llm"].get("name", self.llm_eas_model_name)
         else:
@@ -217,6 +219,7 @@ def to_app_config(self):
         config["llm"]["endpoint"] = self.llm_eas_url
         config["llm"]["token"] = self.llm_eas_token
         config["llm"]["api_key"] = self.llm_api_key
+        config["llm"]["temperature"] = self.llm_temperature
         if self.llm == "PaiEas":
             config["llm"]["name"] = self.llm_eas_model_name
         else:
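The two hooks above give the temperature setting a full round trip: config dict in, view-model field, config dict out. A stripped-down sketch of that pattern (a hypothetical MiniViewModel mirroring only the field this PR adds):

```python
from pydantic import BaseModel

class MiniViewModel(BaseModel):
    # Only the field added in this PR; the real ViewModel has many more.
    llm_temperature: float = 0.1

    def sync_app_config(self, config: dict) -> None:
        # Fall back to the current value when the key is absent.
        self.llm_temperature = config["llm"].get("temperature", self.llm_temperature)

    def to_app_config(self) -> dict:
        return {"llm": {"temperature": self.llm_temperature}}

vm = MiniViewModel()
vm.sync_app_config({"llm": {"temperature": 0.3}})
assert vm.to_app_config()["llm"]["temperature"] == 0.3
```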
19 changes: 13 additions & 6 deletions src/pai_rag/modules/llm/llm_module.py
@@ -24,13 +24,13 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
                 f"""
                 [Parameters][LLM:OpenAI]
                 model = {config.get("name", "gpt-3.5-turbo")},
-                temperature = {config.get("temperature", 0.5)},
+                temperature = {config.get("temperature", 0.1)},
                 system_prompt = {config.get("system_prompt", "Please answer in Chinese.")}
                 """
             )
             llm = OpenAI(
                 model=config.get("name", "gpt-3.5-turbo"),
-                temperature=config.get("temperature", 0.5),
+                temperature=config.get("temperature", 0.1),
                 system_prompt=config.get("system_prompt", "Please answer in Chinese."),
                 api_key=config.get("api_key", None),
             )
@@ -39,13 +39,13 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
                 f"""
                 [Parameters][LLM:AzureOpenAI]
                 model = {config.get("name", "gpt-35-turbo")},
-                temperature = {config.get("temperature", 0.5)},
+                temperature = {config.get("temperature", 0.1)},
                 system_prompt = {config.get("system_prompt", "Please answer in Chinese.")}
                 """
             )
             llm = AzureOpenAI(
                 model=config.get("name", "gpt-35-turbo"),
-                temperature=config.get("temperature", 0.5),
+                temperature=config.get("temperature", 0.1),
                 system_prompt=config.get("system_prompt", "Please answer in Chinese."),
             )
         elif source == "dashscope":
@@ -56,7 +56,9 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
                 model = {model_name}
                 """
             )
-            llm = DashScope(model_name=model_name)
+            llm = DashScope(
+                model_name=model_name, temperature=config.get("temperature", 0.1)
+            )
         elif source == "paieas":
             model_name = config["name"]
             endpoint = config["endpoint"]
@@ -69,7 +71,12 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
                 token = {token}
                 """
             )
-            llm = PaiEAS(endpoint=endpoint, token=token, model_name=model_name)
+            llm = PaiEAS(
+                endpoint=endpoint,
+                token=token,
+                model_name=model_name,
+                temperature=config.get("temperature", 0.1),
+            )
         else:
             raise ValueError(f"Unknown LLM source: '{config['llm']['source']}'")
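Across all four branches the pattern is the same: the factory threads config.get("temperature", 0.1) into whichever backend the config selects, so every LLM source shares one default. A condensed sketch of that dispatch with stand-in classes (the real module constructs llama-index's OpenAI, AzureOpenAI, DashScope, and PaiEAS):

```python
from typing import Any, Dict

class StubLLM:
    """Stand-in for a real LLM backend; records its constructor kwargs."""
    def __init__(self, **kwargs: Any) -> None:
        self.params = kwargs

def create_llm(config: Dict[str, Any]) -> StubLLM:
    source = config["source"].lower()
    # One shared default (0.1) for every backend, as in this PR.
    temperature = config.get("temperature", 0.1)
    if source == "openai":
        return StubLLM(model=config.get("name", "gpt-3.5-turbo"), temperature=temperature)
    elif source == "dashscope":
        return StubLLM(model_name=config["name"], temperature=temperature)
    else:
        raise ValueError(f"Unknown LLM source: '{config['source']}'")

llm = create_llm({"source": "openai", "temperature": 0.2})
print(llm.params)  # {'model': 'gpt-3.5-turbo', 'temperature': 0.2}
```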