-
Notifications
You must be signed in to change notification settings - Fork 0
/
rag_component.py
58 lines (46 loc) · 2.06 KB
/
rag_component.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import logging
import ollama
import asyncio
from concurrent.futures import ThreadPoolExecutor
from language_utils import get_translation
from langdetect import detect
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RAGComponent:
    """Answer queries with an Ollama chat model, grounded in supplied context.

    The prompt is built from the ``"rag_prompt"`` template returned by
    ``get_translation`` and sent to ``ollama.chat`` on a worker thread so the
    event loop is not blocked while the model runs.
    """

    def __init__(self, model_name: str = "llama3.1:latest") -> None:
        """Store the model name and log that the component is ready.

        Args:
            model_name: Name passed as ``model`` to ``ollama.chat``.
        """
        self.model_name = model_name
        # The thread pool is only created on first access of ``executor``.
        self._executor = None
        logger.info(
            get_translation("rag_component_initialized").format(model=self.model_name)
        )

    @property
    def executor(self) -> ThreadPoolExecutor:
        """Return the component's thread pool, creating it on first access.

        The pool has two workers whose thread names start with ``RAG_Worker``.
        Note: the methods of this class run the model call through
        ``asyncio.to_thread`` and do not submit work to this pool themselves.
        """
        if self._executor is None:
            self._executor = ThreadPoolExecutor(
                max_workers=2, thread_name_prefix="RAG_Worker"
            )
        return self._executor

    async def _generate_answer(self, query: str, context_chunks) -> str:
        """Build the RAG prompt, query the model and return its answer.

        Args:
            query: The user's question; inserted into the prompt template.
            context_chunks: Iterable of mappings, each with a ``'content'``
                string. The contents are joined with newlines to form the
                context section of the prompt.

        Returns:
            The ``content`` of the model's reply message. If the reply is
            empty or has no ``'message'`` entry, the translated
            ``"couldnt_generate_answer"`` text is returned instead. If any
            ``Exception`` is raised along the way it is logged with its
            traceback and the translated ``"error_occurred_while_generating"``
            text (including the error string) is returned.
        """
        try:
            logger.info(
                get_translation("generating_answer_for_query").format(query=query)
            )
            context = "\n".join(chunk['content'] for chunk in context_chunks)
            prompt = get_translation("rag_prompt").format(context=context, query=query)

            # ollama.chat is a blocking call; run it off the event loop.
            # to_thread forwards keyword arguments, so no lambda is needed.
            response = await asyncio.to_thread(
                ollama.chat,
                model=self.model_name,
                messages=[{'role': 'user', 'content': prompt}],
            )

            if response and 'message' in response:
                answer = response['message']['content']
                logger.info(get_translation("answer_generated_successfully"))
                return answer

            logger.warning(get_translation("no_valid_response_from_model"))
            return get_translation("couldnt_generate_answer")
        except Exception as e:
            # Deliberate catch-all boundary: callers always get a string back.
            # logger.exception logs at ERROR level and attaches the traceback.
            logger.exception(
                get_translation("error_generating_answer").format(error=str(e))
            )
            return get_translation("error_occurred_while_generating").format(error=str(e))

    async def generate_answer(self, query: str, context_chunks) -> str:
        """Public entry point; delegates to ``_generate_answer`` unchanged."""
        return await self._generate_answer(query, context_chunks)

    def __del__(self) -> None:
        """Shut the thread pool down (without waiting) if one was created."""
        # getattr: __del__ can run on an instance whose __init__ never set
        # ``_executor`` (e.g. created via __new__), so the attribute may be absent.
        pool = getattr(self, "_executor", None)
        if pool is not None:
            pool.shutdown(wait=False)