Skip to content

Commit

Permalink
Do everything in token space
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin committed Dec 3, 2024
1 parent 37f9db2 commit 5d8ec19
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions chai_lab/data/features/generators/token_bond.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self):
def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict:
return dict(
token_exists_mask=batch["inputs"]["token_exists_mask"],
atom_token_index=batch["inputs"]["atom_token_index"].long(),
# atom_token_index=batch["inputs"]["atom_token_index"].long(),
atom_covalent_bond_indices=batch["inputs"]["atom_covalent_bond_indices"],
)

Expand All @@ -40,7 +40,7 @@ def apply_mask(self, feature: Tensor, mask: Tensor, mask_ty: FeatureType) -> Ten
def _generate(
self,
token_exists_mask: Bool[Tensor, "b n"],
atom_token_index: Int[Tensor, "b a"],
# atom_token_index: Int[Tensor, "b a"],
atom_covalent_bond_indices: list[
tuple[Int[Tensor, "bonds"], Int[Tensor, "bonds"]]
],
Expand All @@ -51,12 +51,14 @@ def _generate(
for batch_idx, (left_indices, right_indices) in enumerate(
atom_covalent_bond_indices
):
left_token_indices = torch.gather(
atom_token_index[batch_idx], dim=0, index=left_indices
)
right_token_indices = torch.gather(
atom_token_index[batch_idx], dim=0, index=right_indices
)
bond_feature[batch_idx][left_token_indices, right_token_indices] = 1
bond_feature[batch_idx][left_indices, right_indices] = 1
# left_token_indices = torch.gather(
# atom_token_index[batch_idx], dim=0, index=left_indices
# )
# right_token_indices = torch.gather(
# atom_token_index[batch_idx], dim=0, index=right_indices
# )
# print(left_token_indices, right_token_indices)
# bond_feature[batch_idx][left_token_indices, right_token_indices] = 1

return self.make_feature(bond_feature.unsqueeze(-1))

0 comments on commit 5d8ec19

Please sign in to comment.