Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arbitrary num_diffn_samples + use ds to enumerate diffusion samples #260

Merged
2 commits merged on Dec 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def run_inference(
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
num_diffn_samples: int = 5,
seed: int | None = None,
device: str | None = None,
low_memory: bool = True,
Expand Down Expand Up @@ -470,6 +471,7 @@ def run_inference(
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
num_diffn_samples=num_diffn_samples,
seed=seed,
device=torch_device,
low_memory=low_memory,
Expand All @@ -488,6 +490,8 @@ def run_folding_on_context(
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
# all diffusion samples come from the same trunk
num_diffn_samples: int = 5,
seed: int | None = None,
device: torch.device | None = None,
low_memory: bool,
Expand Down Expand Up @@ -678,19 +682,19 @@ def run_folding_on_context(
static_diffusion_inputs, device=device
)

def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
# verified manually that ds dimension can be arbitrary in diff module
atom_noised_coords = rearrange(
atom_pos, "(b s) ... -> b s ...", s=s
atom_pos, "(b ds) ... -> b ds ...", ds=ds
).contiguous()
noise_sigma = repeat(sigma, " -> b s", b=batch_size, s=s)
noise_sigma = repeat(sigma, " -> b ds", b=batch_size, ds=ds)
return diffusion_module.forward(
atom_noised_coords=atom_noised_coords.float(),
noise_sigma=noise_sigma.float(),
crop_size=model_size,
**static_diffusion_inputs,
)

num_diffn_samples = 5 # Fixed at export time
inference_noise_schedule = InferenceNoiseSchedule(
s_max=DiffusionConfig.S_tmax,
s_min=4e-4,
Expand Down Expand Up @@ -722,8 +726,8 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
atom_pos,
atom_single_mask=repeat(
atom_single_mask,
"b a -> (b s) a",
s=num_diffn_samples,
"b a -> (b ds) a",
ds=num_diffn_samples,
),
)

Expand All @@ -739,7 +743,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
denoised_pos = _denoise(
atom_pos=atom_pos_hat,
sigma=sigma_hat,
s=num_diffn_samples,
ds=num_diffn_samples,
)
d_i = (atom_pos_hat - denoised_pos) / sigma_hat
atom_pos = atom_pos_hat + (sigma_next - sigma_hat) * d_i
Expand All @@ -749,7 +753,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
denoised_pos = _denoise(
atom_pos,
sigma=sigma_next,
s=num_diffn_samples,
ds=num_diffn_samples,
)
d_i_prime = (atom_pos - denoised_pos) / sigma_next
atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
Expand All @@ -770,13 +774,13 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
token_pair_trunk_repr=token_pair_trunk_repr,
token_single_mask=token_single_mask,
atom_single_mask=atom_single_mask,
atom_coords=atom_pos[s : s + 1],
atom_coords=atom_pos[ds : ds + 1],
token_reference_atom_index=token_reference_atom_index,
atom_token_index=atom_token_indices,
atom_within_token_index=atom_within_token_index,
crop_size=model_size,
)
for s in range(num_diffn_samples)
for ds in range(num_diffn_samples)
]

pae_logits, pde_logits, plddt_logits = [
Expand Down
Loading