From 96d21068c9f02c5844d2195d0737ebb4089e8260 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Sat, 23 Nov 2024 22:07:00 +0000 Subject: [PATCH] whoops --- pyterrier_dr/biencoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyterrier_dr/biencoder.py b/pyterrier_dr/biencoder.py index 1430748..9e7684e 100644 --- a/pyterrier_dr/biencoder.py +++ b/pyterrier_dr/biencoder.py @@ -72,7 +72,7 @@ def encode(self, texts, batch_size=None) -> np.array: return self.bi_encoder_model.encode_queries(texts, batch_size=batch_size or self.batch_size) def transform(self, inp: pd.DataFrame) -> pd.DataFrame: - pta.validate.columns(includes=['query']) + pta.validate.columns(inp, includes=['query']) it = inp['query'].values it, inv = np.unique(it, return_inverse=True) if self.verbose: @@ -95,7 +95,7 @@ def encode(self, texts, batch_size=None) -> np.array: return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size) def transform(self, inp: pd.DataFrame) -> pd.DataFrame: - pta.validate.columns(includes=[self.text_field]) + pta.validate.columns(inp, includes=[self.text_field]) it = inp[self.text_field] if self.verbose: it = pt.tqdm(it, desc='Encoding Docs', unit='doc')