diff --git a/finetune/dataset.py b/finetune/dataset.py index dbf6bbc..74b3f05 100644 --- a/finetune/dataset.py +++ b/finetune/dataset.py @@ -143,7 +143,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32)) context = torch.from_numpy(np.hstack(context, dtype=np.int8)) - if input_ids.shape[-1] > max_length: + if ids.shape[-1] > max_length: ids =ids[:max_length] context = context[:max_length] logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")