Refactor RankingModel class for Text+Numr use case #299

Merged · 1 commit · Sep 7, 2024
3 changes: 2 additions & 1 deletion .github/build_pypi_wheel.sh
@@ -19,8 +19,9 @@ echo "pip: $($PIP --version)"


# Install dependencies
+# TODO: remove pin on setuptools after removing numpy.distutils
echo "Install dependencies..."
-$PIP install setuptools wheel twine auditwheel
+$PIP install 'setuptools<=73.0.1' wheel twine auditwheel

# Install OpenBLAS
# Using pre-built OpenBLAS lib v0.3.27 hosted on Anaconda
33 changes: 33 additions & 0 deletions examples/msmarco-rankllama/README.md
@@ -0,0 +1,33 @@
# PECOS XMR Reranker on MS-Marco Dataset

This is an example of a PECOS-based RankingModel that reproduces the [RankLLaMA paper](https://arxiv.org/abs/2310.08319).

## How to run

### Training
```bash
torchrun --nnodes 1 --nproc-per-node 8 \
    -m pecos.xmr.reranker.train \
    --config_json_path ./msmarco_qwen2-7B.train.json
```

### Predictions
```bash
python -m pecos.xmr.reranker.predict \
    --config_json_path ./msmarco_qwen2-7B.pred.json
```

## Evaluation
We first convert the predictions (Parquet files under the prediction config's `output_dir`) to TREC run format:
```bash
python -u parquet_to_trec_eval.py -i inference_outputs/ms_marco/qwen2-7B -o inference_outputs/ms_marco/qwen2-7B.pred.trec
```
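
For reference, each line of the resulting TREC run file has the six whitespace-separated columns that `parquet_to_trec_eval.py` writes: query id, the literal `Q0`, passage id, rank, score, and a run tag. The second line below is a made-up example; the ids and score are placeholders, not real MS MARCO output:
```
<inp_id> Q0 <lbl_id> <rank> <score> dense
1048585 Q0 7187158 1 12.375 dense
```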

We then follow the [Pyserini](https://github.com/castorini/pyserini) evaluation protocol to evaluate NDCG@10;
you should see results like:
```bash
python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 dl19-passage inference_outputs/ms_marco/qwen2-7B.pred.trec

Results:
ndcg_cut_10 all 0.7619
```
21 changes: 21 additions & 0 deletions examples/msmarco-rankllama/msmarco_qwen2-7B.pred.json
@@ -0,0 +1,21 @@
{
    "target_data_folder": "./datasets/ms_marco/eval_aux/target",
    "input_data_folder": "./datasets/ms_marco/eval_aux/input",
    "label_data_folder": "./datasets/ms_marco/eval_aux/label",
    "model_path": "./models/ms_marco/qwen2-7B/",
    "output_dir": "./inference_outputs/ms_marco/qwen2-7B/",
    "per_device_eval_batch_size": 1024,
    "dataloader_num_workers": 1,
    "dataloader_prefetch_factor": 10,
    "rerank_max_len": 196,
    "query_prefix": "query: ",
    "passage_prefix": "document: ",
    "inp_id_col": "inp_id",
    "lbl_id_col": "lbl_id",
    "inp_id_orig_col": "inp_id_orig",
    "lbl_id_orig_col": "lbl_id_orig",
    "keyword_col_name": "keywords",
    "content_col_names": ["title", "contents"],
    "append_eos_token": false,
    "pad_to_multiple_of": 8
}
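
As a rough, hypothetical sketch (not the actual PECOS code path), the prefix and column fields above are the knobs that turn a query and a candidate passage into a single reranker input string; `build_rerank_input` and the sample values are invented for illustration only:

```python
# Hypothetical illustration of how query_prefix/passage_prefix and the
# keyword/content columns are meant to combine. The real PECOS pipeline
# tokenizes the result and truncates it to rerank_max_len (196 above).
def build_rerank_input(keywords: str, title: str, contents: str) -> str:
    query_part = "query: " + keywords  # query_prefix + keyword_col_name
    passage_part = "document: " + " ".join([title, contents])  # passage_prefix + content_col_names
    return query_part + " " + passage_part


print(build_rerank_input("what is pecos", "PECOS", "PECOS is a library for eXtreme Multi-label Ranking."))
```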
140 changes: 140 additions & 0 deletions examples/msmarco-rankllama/msmarco_qwen2-7B.train.json
@@ -0,0 +1,140 @@
{
    "train_params": {
        "__meta__": {
            "class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams"
        },
        "target_data_folder": "./datasets/ms_marco/train/target",
        "input_data_folder": "./datasets/ms_marco/train/input",
        "label_data_folder": "./datasets/ms_marco/train/label",
        "hf_trainer_args": {
            "__meta__": {
                "class_fullname": "pecos.xmr.reranker.trainer###RankingTrainer.TrainingArgs"
            },
            "output_dir": "./models/ms_marco/qwen2-7B",
            "ddp_find_unused_parameters": false,
            "loss_fn": "listwise",
            "loss_alpha": 1.0,
            "group_size": 16,
            "per_device_train_batch_size": 6,
            "gradient_accumulation_steps": 8,
            "disable_tqdm": false,
            "logging_strategy": "steps",
            "logging_first_step": false,
            "learning_rate": 1e-4,
            "max_steps": 1500,
            "save_steps": 50,
            "logging_steps": 10,
            "save_strategy": "steps",
            "save_total_limit": 5,
            "seed": 42,
            "data_seed": 42,
            "bf16": true,
            "dataloader_num_workers": 2,
            "dataloader_prefetch_factor": 10,
            "gradient_checkpointing": true,
"deepseed": {
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 1e6,
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 10,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 10,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto",
"torch_adam": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 1000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
}
},
"model_params": {
"__meta__": {
"class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams"
},
"encoder_config": {
"text_config": {
"model_type": "qwen2",
"name_or_path": "Qwen/Qwen2-7B",
"attn_implementation": "sdpa",
"trust_remote_code": true,
"token": null
},
"numr_config": null,
"text_pooling_type": "last",
"head_size_list": [128]
},
"model_modifier": {
"modifier_type": "peft",
"config_type": "LoraConfig" ,
"config": {
"r": 16,
"lora_alpha": 32,
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
"modules_to_save": ["head_layers", "scorer"],
"lora_dropout": 0.1
}
},
"positive_passage_no_shuffle": false,
"negative_passage_no_shuffle": false,
"rerank_max_len": 196,
"query_prefix": "query: ",
"passage_prefix": "document: ",
"inp_id_col": "inp_id",
"lbl_idxs_col": "ret_idxs",
"score_col": "rel",
"keyword_col_name": "keywords",
"content_col_names": ["title", "contents"],
"append_eos_token": false,
"pad_to_multiple_of": 16
}
}
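
A quick back-of-the-envelope check on the batch arithmetic implied by this config, assuming the 8-GPU `torchrun` launch from the example README (the GPU count is an assumption taken from that command, not part of this file):

```python
# Effective batch size implied by the training config above,
# assuming torchrun --nproc-per-node 8 (8 data-parallel workers).
per_device_train_batch_size = 6   # queries per GPU per micro-batch
gradient_accumulation_steps = 8
num_gpus = 8                      # assumption from the example launch command
group_size = 16                   # passages scored per query

queries_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
pairs_per_step = queries_per_step * group_size
print(queries_per_step, pairs_per_step)  # 384 queries -> 6144 query-passage pairs per optimizer step
```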
36 changes: 36 additions & 0 deletions examples/msmarco-rankllama/parquet_to_trec_eval.py
@@ -0,0 +1,36 @@

import argparse
import os

import pandas as pd


def main(args):
    """
    Combine all prediction results from the input folder and write them to the
    output file in TREC run format.
    """
    result_files = [
        os.path.join(args.input_parquet_path, x)
        for x in os.listdir(args.input_parquet_path)
    ]
    # concatenate the per-shard parquet outputs into a single DataFrame
    all_results = pd.concat([pd.read_parquet(f) for f in result_files])
    # sort all results by 'inp_id' and then 'score' in descending order
    all_results = all_results.sort_values(by=['inp_id', 'score'], ascending=[True, False])

    # emit one TREC line per (query, passage) pair, restarting the rank
    # counter whenever a new query id (inp_id) begins
    cur_inp_id = None
    with open(args.output_trec_path, "w") as fout:
        for row in all_results.itertuples():
            if cur_inp_id != row.inp_id:
                cur_inp_id = row.inp_id
                rank = 0
            rank += 1
            fout.write(f"{row.inp_id} Q0 {row.lbl_id} {rank} {row.score} dense\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input-parquet-path", type=str, required=True)
    parser.add_argument("-o", "--output-trec-path", type=str, required=True)
    args = parser.parse_args()
    main(args)
126 changes: 9 additions & 117 deletions pecos/xmr/reranker/README.md
@@ -4,22 +4,14 @@
This is a reranker for the PECOS XMR model. It is based on huggingface's transformers library and can be run in
single process and distributed mode. It is based on the paper [Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319).

## How to run
-### Single process
-To run the reranker in single process mode, you can use the following command:
-
-```bash
-python -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
-```
-
-### Distributed mode
-To run the reranker in distributed mode, you can use the following command to initialize the distributed configuration:
-```bash
-accelerate config
-```
+### Training
+To train the reranker, we suggest using the `torchrun` command:

-Then you can run the reranker using the following command:
```bash
-accelerate launch -m pecos.xmr.reranker.train --config_json_path <path_to_config_file>
+torchrun --nnodes 1 --nproc-per-node 8 \
+    -m pecos.xmr.reranker.train \
+    --config_json_path <path_to_config_file>
```

### Predictions
@@ -28,112 +20,12 @@
To run the reranker in prediction mode, you can use the following command:
python -m pecos.xmr.reranker.predict --config_json_path <path_to_config_file>
```

-## Configuration file
-
-### Training
-Here is an example of the configuration file for training:
-```json
-{
-    "train_params": {
-        "__meta__": {
-            "class_fullname": "pecos.xmr.reranker.model###RankingModel.TrainParams"
-        },
-        "target_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/target",
-        "input_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/input",
-        "label_data_folder": "/home/ec2-user/docker_disk/datasets/ms_marco_partitioned/label",
-        "training_args": {
-            "__meta__": {
-                "class_fullname": "pecos.xmr.reranker.trainer###RankLlamaTrainer.TrainingArgs"
-            },
-            "learning_rate": 1e-4,
-            "output_dir": "./ds_model",
-            "per_device_train_batch_size": 8,
-            "gradient_accumulation_steps": 8,
-            "max_steps": -1,
-            "logging_strategy": "steps",
-            "logging_first_step": false,
-            "logging_steps": 10,
-            "save_strategy": "steps",
-            "save_steps": 50,
-            "save_total_limit": 5,
-            "seed": 42,
-            "data_seed": 42,
-            "bf16": true,
-            "dataloader_num_workers": 2,
-            "dataloader_prefetch_factor": 10,
-            "gradient_checkpointing": true,
-            "train_group_size": 16
-        }
-    },
-    "model_params": {
-        "__meta__": {
-            "class_fullname": "pecos.xmr.reranker.model###RankingModel.ModelParams"
-        },
-        "encoder_args": {
-            "__meta__": {
-                "class_fullname": "pecos.xmr.reranker.model###CrossEncoder.Config"
-            },
-            "model_shortcut": "meta-llama/Llama-2-7b-hf",
-            "model_init_kwargs": {},
-            "model_modifier": {
-                "modifier_type": "peft",
-                "config_type": "LoraConfig",
-                "config": {
-                    "r": 8,
-                    "lora_alpha": 64,
-                    "target_modules": ["q_proj", "v_proj"],
-                    "modules_to_save": ["score", "classifier"],
-                    "lora_dropout": 0.1
-                }
-            }
-        },
-        "positive_passage_no_shuffle": false,
-        "negative_passage_no_shuffle": false,
-        "rerank_max_len": 196,
-        "query_prefix": "query: ",
-        "passage_prefix": "document: ",
-        "inp_id_col": "inp_id",
-        "lbl_idxs_col": "ret_idxs",
-        "score_col": "rel",
-        "keyword_col_name": "keywords",
-        "content_col_names": ["title", "contents"],
-        "append_eos_token": false,
-        "pad_to_multiple_of": 16
-    }
-}
-```
-
-### Prediction
-Following is the example of the configuration file for prediction:
-```json
-{
-    "model_name_or_path": "/tmp/pecosdev/ds_model",
-    "target_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/target",
-    "input_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/input",
-    "label_data_folder": "/home/ec2-user/docker_disk/datasets/msmarcoeval/label",
-    "output_dir": "/tmp/xmrout",
-    "per_device_eval_batch_size": 512,
-    "dataloader_num_workers": 1,
-    "dataloader_prefetch_factor": 10,
-    "rerank_max_len": 196,
-    "query_prefix": "query: ",
-    "passage_prefix": "document: ",
-    "inp_id_col": "inp_id",
-    "lbl_id_col": "lbl_id",
-    "keyword_col_name": "keywords",
-    "content_col_names": ["title", "contents"],
-    "append_eos_token": false,
-    "pad_to_multiple_of": 8,
-    "device": "cuda",
-    "model_init_kwargs": {
-        "device_map": "auto"
-    }
-}
-```
+## Config JSON Files
+See example training/predict JSON files in the `pecos/examples/msmarco-rankllama` folder.

## Data Schema
-The column names for the data schema are configurable through the json configuration file. Following
-are the various schemas that are supported by the reranker:
+The column names for the data schema are configurable through the json configuration file.
+Following are the various schemas that are supported by the reranker:

(1) Learning Target Schema