eval.py
# general
import os
import logging
import time
import asyncio
import warnings
from dotenv import load_dotenv
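# globally suppress DeprecationWarning output (e.g. noisy deprecation warnings raised by dependencies)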
warnings.filterwarnings("ignore", category=DeprecationWarning)
# rag specific
from rag.parser import DocParser, ParseConfig
from rag.embedding import Embedding
from rag.model import Model
from rag.indexer import Indexer
from rag.retriever import Retriever
from rag.evaluate import QADataset, Evaluator
from rag.common.constants import (
    CHUNKING_STRATEGIES,
    EMBED_PROVIDERS,
    LLM_MODELS,
    RETRIEVAL_STRATEGIES,
    EVAL_MODEL,
    DATASET_PATH,
    STORE_PATH,
    QA_DATASET_PATH,
)
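# the constants above define the configuration grid (chunking x embedding x LLM x retrieval)
# plus the dataset, index store, and QA-dataset paths used by eval() below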
# misc
import pandas as pd
from llama_index import ServiceContext
# NOTE: make sure .env has all API keys as per README.md
load_dotenv()
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
async def eval():
    """Run the evaluation sweep: for every combination of chunking strategy,
    embedding provider, LLM, and retrieval strategy, build an index, query it,
    score retrieval and response quality, and export consolidated metrics."""
    all_results = []
    start = time.time()
    logging.info("Starting evaluation ..")
    slug = None  # current configuration slug, so the except block can report it
    try:
        for chunking_strategy in CHUNKING_STRATEGIES:
            # 1. load, parse, and chunk input documents into nodes
            doc_parser = DocParser(config=ParseConfig())
            docs = doc_parser.load_docs(DATASET_PATH)
            logging.info(f"Applying chunking strategy: {chunking_strategy} ..")
            nodes = doc_parser.get_nodes(docs=docs, chunking_strategy=chunking_strategy)
            logging.info(f"Number of nodes: {len(nodes)}")
            # 2. generate or load the QA dataset
            if os.path.exists(QA_DATASET_PATH):
                logging.info("QA dataset exists ..")
                qa_dataset = QADataset.load_dataset(path=QA_DATASET_PATH)
            else:
                logging.info("QA dataset does not exist. Generating ..")
                os.makedirs("./qa_dataset", exist_ok=True)
                qa_dataset = QADataset.generate_dataset(
                    nodes=nodes,
                    path=QA_DATASET_PATH,
                )
            logging.info("Loaded QA dataset ..")
            # sweep every embedding provider / LLM / retrieval strategy combination
            for embed_provider in EMBED_PROVIDERS:
                embed_model = Embedding(embed_provider=embed_provider).model
                for llm_model in LLM_MODELS:
                    llm = Model(model_name=llm_model).model
                    logging.info(
                        f"Initialized LLM: {llm_model.value}, Embedding: {embed_provider.value} .."
                    )
                    for retrieval_strategy in RETRIEVAL_STRATEGIES:
                        slug = f"{chunking_strategy.value}_{embed_provider.value}_{llm_model.value}_{retrieval_strategy.value}"
                        logging.info(f"Starting evaluation: {slug}")
                        # 3. generate vector index from nodes
                        service_context = ServiceContext.from_defaults(
                            llm=llm, embed_model=embed_model
                        )
                        indexer = Indexer(service_context)
                        vector_index = indexer.get_vector_index(
                            nodes=nodes, store_dir=f"{STORE_PATH}/{slug}"
                        )
                        logging.info("Generated vector index ..")
                        # 4. create retriever and query engine
                        retriever = Retriever(
                            nodes=nodes,
                            vector_index=vector_index,
                            chunking_strategy=chunking_strategy,
                            retrieval_strategy=retrieval_strategy,
                            service_context=service_context,
                        )
                        retriever_chunk = retriever.get_retriever()
                        query_engine = retriever.get_query_engine()
                        logging.info("Initialized retriever ..")
                        # 5. evaluate RAG performance and compute metrics
                        llm_eval_model = Model(model_name=EVAL_MODEL).model
                        eval_service_context = ServiceContext.from_defaults(
                            llm=llm_eval_model
                        )
                        evaluator = Evaluator(service_context=eval_service_context)
                        # 5.a evaluate retrieval (metrics: MRR, Hit Rate)
                        retrieval_results = await evaluator.evaluate_retrieval(
                            qa_dataset=qa_dataset,
                            retriever=retriever_chunk,
                        )
                        # 5.b evaluate response (metrics: faithfulness, relevancy)
                        response_results = await evaluator.evaluate_response(
                            qa_dataset=qa_dataset,
                            query_engine=query_engine,
                            max_queries=10,
                        )
                        # consolidate results
                        metrics = evaluator.get_eval_metrics(
                            name=slug,
                            chunking_strategy=chunking_strategy.value,
                            embed_provider=embed_provider.value,
                            llm_model=llm_model.value,
                            retrieval_strategy=retrieval_strategy.value,
                            retrieval_results=retrieval_results,
                            response_results=response_results,
                        )
                        all_results.append(metrics)
        logging.info("Finished evaluation.")
    except Exception as e:
        logging.error(f"Error evaluating: {slug} - {e}")
        raise Exception(f"Error evaluating: {slug}") from e
    all_results_df = pd.concat(all_results)
    os.makedirs("./results", exist_ok=True)  # make sure the output directory exists
    all_results_df.to_csv("./results/all_results.csv", index=False)  # export results
    logging.info(f"Total time taken: {(time.time() - start)/60:.2f} mins.")


if __name__ == "__main__":
    asyncio.run(eval())