-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
081db58
commit be6a3cf
Showing
1 changed file
with
396 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,396 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7e2d65ff-eea0-44bd-a736-9675ae28ad9c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install --upgrade git+https://github.com/terrierteam/pyterrier_dr.git\n", | ||
"import pyterrier as pt\n", | ||
"pt.utils.set_tqdm('notebook')\n", | ||
"import pyterrier_dr" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bd1ae81d-21fe-49d1-b282-a06bc6a53ce7", | ||
"metadata": {}, | ||
"source": [ | ||
"Load your index and the relevant model\n", | ||
"\n", | ||
"If you dont have an existing TCT index, one can be created for MSMARCO as follows:\n", | ||
"```python\n", | ||
"index = pyterrier_dr.FlexIndex('myindex.flex')\n", | ||
"idx_pipeline = model >> index\n", | ||
"idx_pipeline.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "9dc40df5-9898-4888-a2c1-72c27f3d7bb7", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "8d57b49b48444edcad1957c10465e3db", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"tokenizer_config.json: 0%| | 0.00/334 [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "eb0c8d86392843089c0f28fafd0574b9", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"config.json: 0%| | 0.00/559 [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "4772adacafa14c6194569fb48358c0a7", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "30fc505f140d482baceff1e46ba0630e", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"special_tokens_map.json: 0%| | 0.00/112 [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "481d4491e9bb42d381244c3130254df1", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"pytorch_model.bin: 0%| | 0.00/438M [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/miniconda3/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", | ||
" return self.fget.__get__(instance, owner)()\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"index_loc = '/nfs/global_indices/msmarco-passage.tct-hnp.flex'\n", | ||
"model = pyterrier_dr.TctColBert('castorini/tct_colbert-v2-hnp-msmarco')\n", | ||
"index = pyterrier_dr.FlexIndex(index_loc)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "46556e84-b895-47b4-b439-8016cb4081cd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# standard dense retrieval\n", | ||
"baseline = model >> index" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "75d84842-93c8-455a-9c96-e1d74f166298", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# PRF pipelines\n", | ||
"vector_prf = model >> index % 3 >> index.vec_loader() >> pyterrier_dr.VectorPrf() >> index\n", | ||
"average_prf = model >> index % 3 >> index.vec_loader() >> pyterrier_dr.AveragePrf() >> index" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"id": "389f1b82-9a81-4e76-8d43-8638ee24bd9f", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Java started (triggered by _read_topics_singleline) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.10 (build: craigm 2024-08-22 17:33), helper_version=0.0.8]\n", | ||
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n", | ||
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "2e183ef4c5a3493dba2ed54d3600ba3a", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n", | ||
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "a6551a6965c54aafbc59c45de6609334", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n", | ||
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "8a828590e0da4fba96609bb82cd48446", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n", | ||
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "c5c6c1658fc04fe2bc48e158ca6eb150", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n", | ||
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "35514c524f7747ee8d01504ca6f25adc", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div>\n", | ||
"<style scoped>\n", | ||
" .dataframe tbody tr th:only-of-type {\n", | ||
" vertical-align: middle;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe tbody tr th {\n", | ||
" vertical-align: top;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe thead th {\n", | ||
" text-align: right;\n", | ||
" }\n", | ||
"</style>\n", | ||
"<table border=\"1\" class=\"dataframe\">\n", | ||
" <thead>\n", | ||
" <tr style=\"text-align: right;\">\n", | ||
" <th></th>\n", | ||
" <th>name</th>\n", | ||
" <th>nDCG@10</th>\n", | ||
" <th>AP(rel=2)@100</th>\n", | ||
" <th>nDCG@10 +</th>\n", | ||
" <th>nDCG@10 -</th>\n", | ||
" <th>nDCG@10 p-value</th>\n", | ||
" <th>AP(rel=2)@100 +</th>\n", | ||
" <th>AP(rel=2)@100 -</th>\n", | ||
" <th>AP(rel=2)@100 p-value</th>\n", | ||
" </tr>\n", | ||
" </thead>\n", | ||
" <tbody>\n", | ||
" <tr>\n", | ||
" <th>0</th>\n", | ||
" <td>TCT</td>\n", | ||
" <td>0.720577</td>\n", | ||
" <td>0.402403</td>\n", | ||
" <td>NaN</td>\n", | ||
" <td>NaN</td>\n", | ||
" <td>NaN</td>\n", | ||
" <td>NaN</td>\n", | ||
" <td>NaN</td>\n", | ||
" <td>NaN</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>1</th>\n", | ||
" <td>TCT >> VectorPRF</td>\n", | ||
" <td>0.732201</td>\n", | ||
" <td>0.425013</td>\n", | ||
" <td>24.0</td>\n", | ||
" <td>11.0</td>\n", | ||
" <td>0.203288</td>\n", | ||
" <td>30.0</td>\n", | ||
" <td>12.0</td>\n", | ||
" <td>0.000902</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>2</th>\n", | ||
" <td>TCT >> AveragePRF</td>\n", | ||
" <td>0.729361</td>\n", | ||
" <td>0.441232</td>\n", | ||
" <td>25.0</td>\n", | ||
" <td>14.0</td>\n", | ||
" <td>0.505680</td>\n", | ||
" <td>28.0</td>\n", | ||
" <td>14.0</td>\n", | ||
" <td>0.032259</td>\n", | ||
" </tr>\n", | ||
" </tbody>\n", | ||
"</table>\n", | ||
"</div>" | ||
], | ||
"text/plain": [ | ||
" name nDCG@10 AP(rel=2)@100 nDCG@10 + nDCG@10 - \\\n", | ||
"0 TCT 0.720577 0.402403 NaN NaN \n", | ||
"1 TCT >> VectorPRF 0.732201 0.425013 24.0 11.0 \n", | ||
"2 TCT >> AveragePRF 0.729361 0.441232 25.0 14.0 \n", | ||
"\n", | ||
" nDCG@10 p-value AP(rel=2)@100 + AP(rel=2)@100 - AP(rel=2)@100 p-value \n", | ||
"0 NaN NaN NaN NaN \n", | ||
"1 0.203288 30.0 12.0 0.000902 \n", | ||
"2 0.505680 28.0 14.0 0.032259 " | ||
] | ||
}, | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"\n", | ||
"import pyterrier as pt\n", | ||
"from pyterrier.measures import *\n", | ||
"pt.Experiment(\n", | ||
" [baseline, vector_prf, average_prf],\n", | ||
" pt.get_dataset('msmarco_passage').get_topics('test-2019'),\n", | ||
" pt.get_dataset('msmarco_passage').get_qrels('test-2019'),\n", | ||
" [nDCG@10, AP(rel=2)@100],\n", | ||
" names=[\"TCT\", \"TCT >> VectorPRF\", \"TCT >> AveragePRF\"],\n", | ||
" baseline=0\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0ecfd01d-7365-4f04-a3f1-563db0c5f920", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |