low-memory regime (#172)
low-memory regime (#172)
* low-memory regime

* add low memory mode for bond_loss_input_proj

* remove torch.device, because run_inference is exposed by typer
arogozhnikov authored Dec 9, 2024
1 parent cd61b3b commit 71eff6a
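
In the low-memory regime (now the default), the collated batch stays on the CPU; each exported submodule receives move_to_device=device so that only its own inputs are copied to the accelerator, and return_on_cpu=low_memory offloads large outputs back to host memory. A minimal usage sketch: low_memory and the string-typed device come from this diff, while fasta_file is an assumed parameter name that is not visible here.

from pathlib import Path

from chai_lab.chai1 import run_inference

candidates = run_inference(
    fasta_file=Path("input.fasta"),  # assumed parameter name; not shown in this diff
    output_dir=Path("outputs"),
    device="cuda:0",  # plain string rather than torch.device, since typer exposes run_inference
    low_memory=True,  # keep the batch on CPU and move tensors per submodule call
)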
Showing 1 changed file with 59 additions and 18 deletions.

chai_lab/chai1.py
@@ -107,8 +107,24 @@ class ModuleWrapper:
     def __init__(self, jit_module):
         self.jit_module = jit_module
 
-    def forward(self, crop_size: int, **kw):
-        return getattr(self.jit_module, f"forward_{crop_size}")(**kw)
+    def forward(
+        self,
+        crop_size: int,
+        *,
+        return_on_cpu=False,
+        move_to_device: torch.device | None = None,
+        **kw,
+    ):
+        f = getattr(self.jit_module, f"forward_{crop_size}")
+        if move_to_device is not None:
+            result = f(**move_data_to_device(kw, device=move_to_device))
+        else:
+            result = f(**kw)
+
+        if return_on_cpu:
+            return move_data_to_device(result, device=torch.device("cpu"))
+        else:
+            return result
 
 
 def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper:
@@ -295,6 +311,7 @@ def run_inference(
     num_diffn_timesteps: int = 200,
     seed: int | None = None,
     device: str | None = None,
+    low_memory: bool = True,
 ) -> StructureCandidates:
     if output_dir.exists():
         assert not any(
@@ -421,6 +438,7 @@ def run_inference(
         num_diffn_timesteps=num_diffn_timesteps,
         seed=seed,
         device=torch_device,
+        low_memory=low_memory,
     )
 
 
@@ -438,6 +456,7 @@ def run_folding_on_context(
     num_diffn_timesteps: int = 200,
     seed: int | None = None,
     device: torch.device | None = None,
+    low_memory: bool,
 ) -> StructureCandidates:
     """
     Function for in-depth explorations.
@@ -479,7 +498,8 @@ def run_folding_on_context(
     batch_size = len(feature_contexts)
     batch = collator(feature_contexts)
 
-    batch = move_data_to_device(batch, device=device)
+    if not low_memory:
+        batch = move_data_to_device(batch, device=device)
 
     # Get features and inputs from batch
     features = {name: feature for name, feature in batch["features"].items()}
@@ -516,7 +536,12 @@ def run_folding_on_context(
     ## Run the features through the feature embedder
     ##
 
-    embedded_features = feature_embedding.forward(crop_size=model_size, **features)
+    embedded_features = feature_embedding.forward(
+        crop_size=model_size,
+        move_to_device=device,
+        return_on_cpu=low_memory,
+        **features,
+    )
     token_single_input_feats = embedded_features["TOKEN"]
     token_pair_input_feats, token_pair_structure_input_feats = embedded_features[
         "TOKEN_PAIR"
@@ -538,7 +563,10 @@
     bond_ft_gen = TokenBondRestraint()
     bond_ft = bond_ft_gen.generate(batch=batch).data
     trunk_bond_feat, structure_bond_feat = bond_loss_input_proj.forward(
-        crop_size=model_size, input=bond_ft
+        return_on_cpu=low_memory,
+        move_to_device=device,
+        crop_size=model_size,
+        input=bond_ft,
     ).chunk(2, dim=-1)
     token_pair_input_feats += trunk_bond_feat
     token_pair_structure_input_feats += structure_bond_feat
@@ -548,6 +576,8 @@
     ##
 
     token_input_embedder_outputs: tuple[Tensor, ...] = token_input_embedder.forward(
+        return_on_cpu=low_memory,
+        move_to_device=device,
         token_single_input_feats=token_single_input_feats,
         token_pair_input_feats=token_pair_input_feats,
         atom_single_input_feats=atom_single_input_feats,
@@ -573,6 +603,7 @@
     token_pair_trunk_repr = token_pair_initial_repr
     for _ in tqdm(range(num_trunk_recycles), desc="Trunk recycles"):
         (token_single_trunk_repr, token_pair_trunk_repr) = trunk.forward(
+            move_to_device=device,
             token_single_trunk_initial_repr=token_single_initial_repr,
             token_pair_trunk_initial_repr=token_pair_initial_repr,
             token_single_trunk_repr=token_single_trunk_repr,  # recycled
@@ -593,27 +624,36 @@
     ## Denoise the trunk representation by passing it through the diffusion module
     ##
 
+    atom_single_mask = atom_single_mask.to(device)
+
+    static_diffusion_inputs = dict(
+        token_single_initial_repr=token_single_structure_input.float(),
+        token_pair_initial_repr=token_pair_structure_input_feats.float(),
+        token_single_trunk_repr=token_single_trunk_repr.float(),
+        token_pair_trunk_repr=token_pair_trunk_repr.float(),
+        atom_single_input_feats=atom_single_structure_input_feats.float(),
+        atom_block_pair_input_feats=block_atom_pair_structure_input_feats.float(),
+        atom_single_mask=atom_single_mask,
+        atom_block_pair_mask=block_atom_pair_mask,
+        token_single_mask=token_single_mask,
+        block_indices_h=block_indices_h,
+        block_indices_w=block_indices_w,
+        atom_token_indices=atom_token_indices,
+    )
+    static_diffusion_inputs = move_data_to_device(
+        static_diffusion_inputs, device=device
+    )
+
     def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
         atom_noised_coords = rearrange(
             atom_pos, "(b s) ... -> b s ...", s=s
         ).contiguous()
         noise_sigma = repeat(sigma, " -> b s", b=batch_size, s=s)
         return diffusion_module.forward(
-            token_single_initial_repr=token_single_structure_input.float(),
-            token_pair_initial_repr=token_pair_structure_input_feats.float(),
-            token_single_trunk_repr=token_single_trunk_repr.float(),
-            token_pair_trunk_repr=token_pair_trunk_repr.float(),
-            atom_single_input_feats=atom_single_structure_input_feats.float(),
-            atom_block_pair_input_feats=block_atom_pair_structure_input_feats.float(),
-            atom_single_mask=atom_single_mask,
-            atom_block_pair_mask=block_atom_pair_mask,
-            token_single_mask=token_single_mask,
-            block_indices_h=block_indices_h,
-            block_indices_w=block_indices_w,
             atom_noised_coords=atom_noised_coords.float(),
             noise_sigma=noise_sigma.float(),
-            atom_token_indices=atom_token_indices,
             crop_size=model_size,
+            **static_diffusion_inputs,
         )
 
     num_diffn_samples = 5  # Fixed at export time
@@ -681,7 +721,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
         atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
 
     # We won't be running diffusion anymore
-    del diffusion_module
+    del diffusion_module, static_diffusion_inputs
     torch.cuda.empty_cache()
 
     ##
@@ -690,6 +730,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
 
     confidence_outputs: list[tuple[Tensor, ...]] = [
         confidence_head.forward(
+            move_to_device=device,
             token_single_input_repr=token_single_initial_repr,
             token_single_trunk_repr=token_single_trunk_repr,
             token_pair_trunk_repr=token_pair_trunk_repr,
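
Taken together, the hunks implement one pattern: tensors stay on the CPU by default, ModuleWrapper.forward copies a stage's keyword arguments to the device just before the jit-exported submodule runs, and results are optionally copied back. A simplified stand-in for move_data_to_device that illustrates the assumed semantics (the real helper is imported by chai1.py and may handle more container types):

import torch

def move_data_to_device(obj, device: torch.device):
    # Move tensors to `device`, rebuilding dicts, lists, and tuples around
    # the moved elements; non-tensor leaves pass through unchanged.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_data_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_data_to_device(v, device) for v in obj)
    return obj

Grouping the diffusion inputs into static_diffusion_inputs and moving them to the device once also lets a single del diffusion_module, static_diffusion_inputs followed by torch.cuda.empty_cache() reclaim that memory before the confidence head runs.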
