KeyError #115
Comments
Thank you. I have tried this solution before, but it didn't work. Maybe I should change my package versions according to requirements.txt.
@JLUGQQ: I am able to run eval successfully on both zeshel and non-zeshel datasets. Feel free to copy and paste your error message here; I'd be glad to take a look.
Thank you very much for your help!
@JLUGQQ: Here's what I have:
diff --git a/blink/biencoder/nn_prediction.py b/blink/biencoder/nn_prediction.py
index eab90a8..18e50cd 100644
--- a/blink/biencoder/nn_prediction.py
+++ b/blink/biencoder/nn_prediction.py
@@ -55,13 +55,20 @@ def get_topk_predictions(
     oid = 0
     for step, batch in enumerate(iter_):
         batch = tuple(t.to(device) for t in batch)
-        context_input, _, srcs, label_ids = batch
+        if is_zeshel:
+            context_input, _, srcs, label_ids = batch
+        else:
+            context_input, _, label_ids = batch
+            srcs = torch.tensor([0] * context_input.size(0), device=device)
+
         src = srcs[0].item()
+        cand_encode_list[src] = cand_encode_list[src].to(device)
         scores = reranker.score_candidate(
             context_input,
             None,
-            cand_encs=cand_encode_list[src].to(device)
+            cand_encs=cand_encode_list[src]
         )
+
         values, indicies = scores.topk(top_k)
         old_src = src
         for i in range(context_input.size(0)):
@@ -93,7 +100,7 @@ def get_topk_predictions(
                 continue
             # add examples in new_data
-            cur_candidates = candidate_pool[src][inds]
+            cur_candidates = candidate_pool[srcs[i].item()][inds]
             nn_context.append(context_input[i].cpu().tolist())
             nn_candidates.append(cur_candidates.cpu().tolist())
             nn_labels.append(pointer)
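The batch-unpacking branch in the patch above can be tried in isolation. The following is only a minimal sketch under stated assumptions: plain Python lists stand in for tensors, and `unpack_batch` is a hypothetical helper written for illustration, not a function in BLINK.

```python
def unpack_batch(batch, is_zeshel):
    """Split a batch into (context_input, srcs, label_ids).

    Hypothetical helper mirroring the patched branch: zeshel batches carry
    a per-example world index (srcs); non-zeshel batches do not, so every
    example is assigned world 0.
    """
    if is_zeshel:
        context_input, _, srcs, label_ids = batch
    else:
        context_input, _, label_ids = batch
        # Assumption: a single domain, so every example belongs to world 0.
        srcs = [0] * len(context_input)
    return context_input, srcs, label_ids


# Zeshel-style batch: four fields, including per-example world indices.
zeshel_batch = ([[1, 2], [3, 4]], [[0], [0]], [2, 5], [7, 8])
_, zeshel_srcs, _ = unpack_batch(zeshel_batch, is_zeshel=True)

# Non-zeshel batch: three fields, no world indices.
plain_batch = ([[1, 2], [3, 4], [5, 6]], [[0], [0], [0]], [7, 8, 9])
_, plain_srcs, plain_labels = unpack_batch(plain_batch, is_zeshel=False)
```

Unpacking a three-field batch into four names is what would otherwise fail on non-zeshel data, which is why the patch branches on `is_zeshel`.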
Pity. It still doesn't work. Thanks for your reply. I think I should take some time to debug and find the exact cause. I will comment here if I solve the problem.
The KeyError might happen because the validation or test set tries to look up its encodings in the training-set encodings. For example, the validation data has the src value 9, while the training encodings only have src values 0 to 8, so the lookup crashes with KeyError: 9.
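The failure mode described above can be reproduced without BLINK at all. This is a rough sketch, assuming the candidate encodings are held in a dict keyed by src (world) index; `train_cand_encodings` and `lookup_encodings` are hypothetical names used only for illustration, and strings stand in for the real encoding tensors.

```python
# Hypothetical stand-in for encodings built from the training split:
# keys are src (world) indices 0..8, values would be tensors in practice.
train_cand_encodings = {src: f"encodings for world {src}" for src in range(9)}


def lookup_encodings(cand_encode_list, src):
    """Return the encodings for world `src`, failing with a clearer message."""
    if src not in cand_encode_list:
        available = sorted(cand_encode_list)
        raise KeyError(
            f"no candidate encodings for src={src}; available worlds: {available}. "
            "Check that the encodings were built for the split being evaluated."
        )
    return cand_encode_list[src]


val_src = 9  # a validation example whose world is absent from the training encodings

# A bare lookup fails with the uninformative KeyError: 9 seen in the traceback.
try:
    train_cand_encodings[val_src]
except KeyError as err:
    bare_key = err.args[0]

# The guarded lookup fails too, but says which worlds are actually available.
try:
    lookup_encodings(train_cand_encodings, val_src)
except KeyError as err:
    guarded_message = err.args[0]
```

If this is indeed the cause, the guard only improves the error message; the evaluation would still need candidate encodings that cover the worlds of the split being evaluated.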
When I run eval_biencoder, I encounter this problem:
Traceback (most recent call last):
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 336, in <module>
    main(new_params)
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 289, in main
    save_results,
  File "/data/gavin/blink-el/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
    cand_encs=cand_encode_list[src].to(device)
KeyError: 9