Skip to content

Commit

Permalink
arbitrary num_diffn_samples + use ds to enumerate diffusion samples (#…
Browse files Browse the repository at this point in the history
…260)

* use ds to enumerate diffusion samples

* expose number of diffusion samples, use ds as axis for diffusion samples
  • Loading branch information
arogozhnikov authored Dec 24, 2024
1 parent 550aa0f commit 4097a06
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def run_inference(
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
num_diffn_samples: int = 5,
seed: int | None = None,
device: str | None = None,
low_memory: bool = True,
Expand Down Expand Up @@ -470,6 +471,7 @@ def run_inference(
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
num_diffn_samples=num_diffn_samples,
seed=seed,
device=torch_device,
low_memory=low_memory,
Expand All @@ -488,6 +490,8 @@ def run_folding_on_context(
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
# all diffusion samples come from the same trunk
num_diffn_samples: int = 5,
seed: int | None = None,
device: torch.device | None = None,
low_memory: bool,
Expand Down Expand Up @@ -678,19 +682,19 @@ def run_folding_on_context(
static_diffusion_inputs, device=device
)

def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
# verified manually that ds dimension can be arbitrary in diff module
atom_noised_coords = rearrange(
atom_pos, "(b s) ... -> b s ...", s=s
atom_pos, "(b ds) ... -> b ds ...", ds=ds
).contiguous()
noise_sigma = repeat(sigma, " -> b s", b=batch_size, s=s)
noise_sigma = repeat(sigma, " -> b ds", b=batch_size, ds=ds)
return diffusion_module.forward(
atom_noised_coords=atom_noised_coords.float(),
noise_sigma=noise_sigma.float(),
crop_size=model_size,
**static_diffusion_inputs,
)

num_diffn_samples = 5 # Fixed at export time
inference_noise_schedule = InferenceNoiseSchedule(
s_max=DiffusionConfig.S_tmax,
s_min=4e-4,
Expand Down Expand Up @@ -722,8 +726,8 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
atom_pos,
atom_single_mask=repeat(
atom_single_mask,
"b a -> (b s) a",
s=num_diffn_samples,
"b a -> (b ds) a",
ds=num_diffn_samples,
),
)

Expand All @@ -739,7 +743,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
denoised_pos = _denoise(
atom_pos=atom_pos_hat,
sigma=sigma_hat,
s=num_diffn_samples,
ds=num_diffn_samples,
)
d_i = (atom_pos_hat - denoised_pos) / sigma_hat
atom_pos = atom_pos_hat + (sigma_next - sigma_hat) * d_i
Expand All @@ -749,7 +753,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
denoised_pos = _denoise(
atom_pos,
sigma=sigma_next,
s=num_diffn_samples,
ds=num_diffn_samples,
)
d_i_prime = (atom_pos - denoised_pos) / sigma_next
atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
Expand All @@ -770,13 +774,13 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
token_pair_trunk_repr=token_pair_trunk_repr,
token_single_mask=token_single_mask,
atom_single_mask=atom_single_mask,
atom_coords=atom_pos[s : s + 1],
atom_coords=atom_pos[ds : ds + 1],
token_reference_atom_index=token_reference_atom_index,
atom_token_index=atom_token_indices,
atom_within_token_index=atom_within_token_index,
crop_size=model_size,
)
for s in range(num_diffn_samples)
for ds in range(num_diffn_samples)
]

pae_logits, pde_logits, plddt_logits = [
Expand Down

0 comments on commit 4097a06

Please sign in to comment.