Skip to content

Commit

Permalink
expose number of diffusion samples, use ds as axis for diffusion samples
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov committed Dec 24, 2024
1 parent 275b7f2 commit 5dc0f31
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def run_inference(
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
num_diffn_samples: int = 5,
seed: int | None = None,
device: str | None = None,
low_memory: bool = True,
Expand Down Expand Up @@ -470,6 +471,7 @@ def run_inference(
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
num_diffn_samples=num_diffn_samples,
seed=seed,
device=torch_device,
low_memory=low_memory,
Expand All @@ -488,6 +490,8 @@ def run_folding_on_context(
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
# all diffusion samples come from the same trunk
num_diffn_samples: int = 5,
seed: int | None = None,
device: torch.device | None = None,
low_memory: bool,
Expand Down Expand Up @@ -679,6 +683,7 @@ def run_folding_on_context(
)

def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
# verified manually that ds dimension can be arbitrary in diff module
atom_noised_coords = rearrange(
atom_pos, "(b ds) ... -> b ds ...", ds=ds
).contiguous()
Expand All @@ -690,7 +695,6 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
**static_diffusion_inputs,
)

num_diffn_samples = 5 # Fixed at export time
inference_noise_schedule = InferenceNoiseSchedule(
s_max=DiffusionConfig.S_tmax,
s_min=4e-4,
Expand Down

0 comments on commit 5dc0f31

Please sign in to comment.