We provide the implementation of C-RAG in this repository.
C-RAG is the first framework to certify generation risks for RAG models. Specifically, C-RAG performs conformal risk analysis for RAG models and certifies an upper confidence bound on generation risks, referred to as the conformal generation risk. C-RAG also provides theoretical guarantees on conformal generation risks for general bounded risk functions under test distribution shifts. C-RAG proves that RAG achieves a lower conformal generation risk than a single LLM when the quality of the retrieval model and the transformer is non-trivial. Extensive empirical results demonstrate the soundness and tightness of the conformal generation risk guarantees across four widely used NLP datasets and four state-of-the-art retrieval models.
Install PyTorch with the corresponding environment and CUDA version following the PyTorch installation instructions.
Then run pip install -r requirement.txt
to install the other necessary packages for the repo.
For the supervised fine-tuned biencoder-based retrieval model, we follow the implementation in LLM-R and provide the model checkpoint at trained_retrieval_models.
Or you can download it with the command:
gdown https://drive.google.com/uc?id=1xOeCz3vt2piHuY00a5q4YCNhDkyCs0VF
Then, put the folder trained_retrieval_models/ under C-RAG/.
We evaluate C-RAG on four widely used NLP datasets, including AESLC, CommonGen, DART, and E2E. We preprocess the data and provide it at data.
Or you can download it with the command:
gdown https://drive.google.com/uc?id=1JJC192wdOmXYZy_hXcGVrXOtMK2LWsv7
Then, put the folder data/ under C-RAG/.
To compute the conformal generation risk, we need to (1) evaluate the raw risk scores for calibration instances following our constrained generation protocol, and (2) compute the conformal generation risks based on empirical risk statistics.
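Step (1) produces a bounded risk score in [0, 1] for each calibration instance. As a rough illustration only (not the repo's actual metric; C-RAG supports general bounded risk functions), one can define the risk as 1 minus a unigram-F1 similarity between a generation and its reference:

```python
# Illustrative risk function, bounded in [0, 1]. The repo's actual
# metric (e.g. a ROUGE-based score) may differ.
def unigram_f1(pred: str, ref: str) -> float:
    """Unigram F1 overlap between a generation and its reference."""
    p, r = pred.lower().split(), ref.lower().split()
    if not p or not r:
        return 0.0
    overlap = sum(min(p.count(w), r.count(w)) for w in set(p))
    if overlap == 0:
        return 0.0
    prec, rec = overlap / len(p), overlap / len(r)
    return 2 * prec * rec / (prec + rec)

def risk_score(pred: str, ref: str) -> float:
    """Bounded risk in [0, 1]: low similarity means high risk."""
    return 1.0 - unigram_f1(pred, ref)
```

Any risk function bounded in [0, 1] can be plugged in; the guarantees in step (2) only rely on boundedness.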
Evaluate raw risk scores for BM25 retrieval model:
sh scripts_raw_risk_scores/bm25.sh
Evaluate raw risk scores for BAAI/bge retrieval model:
sh scripts_raw_risk_scores/baai.sh
Evaluate raw risk scores for OpenAI/text-embedding-ada-002 retrieval model:
sh scripts_raw_risk_scores/openai.sh
Evaluate raw risk scores for LLM-R finetuned biencoder-based retrieval model:
sh scripts_raw_risk_scores/llm-r.sh
- Prepare the prompts via src/inference/generate_few_shot_prompt.py: retrieve relevant examples and store the prompts at outputs/{METHOD}/{METHOD}_test_k{N_RAG}.jsonl.gz
- Evaluate the risks of the prompts on calibration sets via src/conformal_calibration_empirical_risk.py: evaluate the prompts and store the results in outputs/{METHOD}/{LLM}_{METHOD}/
The conformal generation risk computation is based on the empirical risk statistics stored at outputs/{METHOD}/{LLM}_{METHOD}/ in step (1).
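Step (2) turns these empirical risk statistics into an upper confidence bound. The scripts below compute the paper's conformal bound; as a minimal sketch of the idea, a Hoeffding-style bound on the mean of n bounded calibration risks yields an upper confidence bound that holds with probability at least 1 - delta:

```python
import math

def hoeffding_upper_bound(risks: list[float], delta: float) -> float:
    """Upper confidence bound on the expected risk, holding with
    probability >= 1 - delta, for risks bounded in [0, 1].
    Illustrative only: C-RAG's actual conformal bound is tighter."""
    n = len(risks)
    empirical = sum(risks) / n
    margin = math.sqrt(math.log(1.0 / delta) / (2.0 * n))
    return min(1.0, empirical + margin)
```

For example, 200 calibration risks with mean 0.2 at delta = 0.1 give a bound of about 0.276; the margin shrinks as the calibration set grows.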
- Compute conformal generation risks of a single retrieval model and compare it with the simulation results:
sh scripts_conformal_generation_risk/run_conformal_generation_risk.sh
- Compare conformal generation risks of different retrieval models (after running step 1 for corresponding methods):
sh scripts_conformal_generation_risk/run_conformal_generation_risk_comparisons.sh
- Compute conformal generation risks of a single retrieval model under test distribution shifts and compare them with the simulation results:
sh scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk.sh
- Compare conformal generation risks of different retrieval models under test distribution shifts (after running step 1 for the corresponding methods):
sh scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk_comparisons.sh
Compute conformal generation risks under multi-dimensional RAG configurations:
sh scripts_conformal_generation_risk/run_conformal_generation_risk_multi_dim_config.sh
Compute valid configurations given a desired risk level:
sh scripts_conformal_generation_risk/run_conformal_generation_risk_valid_config.sh
Conformal generation risks with varying generation set sizes:
sh scripts_conformal_generation_risk/run_conformal_generation_risk_num_gen.sh
Conformal generation risks with varying generation similarity thresholds:
sh scripts_conformal_generation_risk/run_conformal_generation_risk_similarity_threshold.sh
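Both sweeps vary a single RAG configuration parameter (the generation set size, or the similarity threshold used by the constrained generation protocol) and recompute the bound at each setting. A hypothetical sketch of such a sweep, assuming per-configuration risk lists and using a Hoeffding-style bound as a stand-in for the paper's exact one:

```python
import math

def conformal_bound(risks, delta=0.1):
    # Hoeffding-style stand-in for the paper's conformal bound.
    n = len(risks)
    return min(1.0, sum(risks) / n + math.sqrt(math.log(1 / delta) / (2 * n)))

def sweep(risks_by_config, delta=0.1):
    """Map each candidate configuration value (e.g. similarity threshold
    or generation set size) to its conformal generation risk bound."""
    return {cfg: conformal_bound(r, delta) for cfg, r in risks_by_config.items()}
```

A configuration with lower calibration risks (at a fixed calibration set size) yields a lower conformal generation risk, which is what these comparison plots visualize.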
The inference part of this repo is built on the LLM-R repo.
For any related questions or discussion, please contact [email protected].
If you find our paper and repo useful for your research, please consider citing:
@article{kang2024c,
  title={C-RAG: Certified Generation Risks for Retrieval-Augmented Language Models},
  author={Kang, Mintong and G{\"u}rel, Nezihe Merve and Yu, Ning and Song, Dawn and Li, Bo},
  journal={arXiv preprint arXiv:2402.03181},
  year={2024}
}