Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TCT PRF notebook #29

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
396 changes: 396 additions & 0 deletions examples/tct-prf.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "7e2d65ff-eea0-44bd-a736-9675ae28ad9c",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade git+https://github.com/terrierteam/pyterrier_dr.git\n",
"import pyterrier as pt\n",
"pt.utils.set_tqdm('notebook')\n",
"import pyterrier_dr"
]
},
{
"cell_type": "markdown",
"id": "bd1ae81d-21fe-49d1-b282-a06bc6a53ce7",
"metadata": {},
"source": [
"Load your index and the relevant model\n",
"\n",
"If you dont have an existing TCT index, one can be created for MSMARCO as follows:\n",
"```python\n",
"index = pyterrier_dr.FlexIndex('myindex.flex')\n",
"idx_pipeline = model >> index\n",
"idx_pipeline.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9dc40df5-9898-4888-a2c1-72c27f3d7bb7",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8d57b49b48444edcad1957c10465e3db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/334 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb0c8d86392843089c0f28fafd0574b9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/559 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4772adacafa14c6194569fb48358c0a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30fc505f140d482baceff1e46ba0630e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/112 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "481d4491e9bb42d381244c3130254df1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"pytorch_model.bin: 0%| | 0.00/438M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
}
],
"source": [
"index_loc = '/nfs/global_indices/msmarco-passage.tct-hnp.flex'\n",
"model = pyterrier_dr.TctColBert('castorini/tct_colbert-v2-hnp-msmarco')\n",
"index = pyterrier_dr.FlexIndex(index_loc)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "46556e84-b895-47b4-b439-8016cb4081cd",
"metadata": {},
"outputs": [],
"source": [
"# standard dense retrieval\n",
"baseline = model >> index"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "75d84842-93c8-455a-9c96-e1d74f166298",
"metadata": {},
"outputs": [],
"source": [
"# PRF pipelines\n",
"vector_prf = model >> index % 3 >> index.vec_loader() >> pyterrier_dr.VectorPrf() >> index\n",
"average_prf = model >> index % 3 >> index.vec_loader() >> pyterrier_dr.AveragePrf() >> index"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "389f1b82-9a81-4e76-8d43-8638ee24bd9f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Java started (triggered by _read_topics_singleline) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.10 (build: craigm 2024-08-22 17:33), helper_version=0.0.8]\n",
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2e183ef4c5a3493dba2ed54d3600ba3a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6551a6965c54aafbc59c45de6609334",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8a828590e0da4fba96609bb82cd48446",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c5c6c1658fc04fe2bc48e158ca6eb150",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "35514c524f7747ee8d01504ca6f25adc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>nDCG@10</th>\n",
" <th>AP(rel=2)@100</th>\n",
" <th>nDCG@10 +</th>\n",
" <th>nDCG@10 -</th>\n",
" <th>nDCG@10 p-value</th>\n",
" <th>AP(rel=2)@100 +</th>\n",
" <th>AP(rel=2)@100 -</th>\n",
" <th>AP(rel=2)@100 p-value</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>TCT</td>\n",
" <td>0.720577</td>\n",
" <td>0.402403</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>TCT &gt;&gt; VectorPRF</td>\n",
" <td>0.732201</td>\n",
" <td>0.425013</td>\n",
" <td>24.0</td>\n",
" <td>11.0</td>\n",
" <td>0.203288</td>\n",
" <td>30.0</td>\n",
" <td>12.0</td>\n",
" <td>0.000902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>TCT &gt;&gt; AveragePRF</td>\n",
" <td>0.729361</td>\n",
" <td>0.441232</td>\n",
" <td>25.0</td>\n",
" <td>14.0</td>\n",
" <td>0.505680</td>\n",
" <td>28.0</td>\n",
" <td>14.0</td>\n",
" <td>0.032259</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name nDCG@10 AP(rel=2)@100 nDCG@10 + nDCG@10 - \\\n",
"0 TCT 0.720577 0.402403 NaN NaN \n",
"1 TCT >> VectorPRF 0.732201 0.425013 24.0 11.0 \n",
"2 TCT >> AveragePRF 0.729361 0.441232 25.0 14.0 \n",
"\n",
" nDCG@10 p-value AP(rel=2)@100 + AP(rel=2)@100 - AP(rel=2)@100 p-value \n",
"0 NaN NaN NaN NaN \n",
"1 0.203288 30.0 12.0 0.000902 \n",
"2 0.505680 28.0 14.0 0.032259 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"import pyterrier as pt\n",
"from pyterrier.measures import *\n",
"pt.Experiment(\n",
" [baseline, vector_prf, average_prf],\n",
" pt.get_dataset('msmarco_passage').get_topics('test-2019'),\n",
" pt.get_dataset('msmarco_passage').get_qrels('test-2019'),\n",
" [nDCG@10, AP(rel=2)@100],\n",
" names=[\"TCT\", \"TCT >> VectorPRF\", \"TCT >> AveragePRF\"],\n",
" baseline=0\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ecfd01d-7365-4f04-a3f1-563db0c5f920",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}