Skip to content

Commit

Permalink
Set seed.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Oct 6, 2023
1 parent 242d4ee commit cc33e99
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
3 changes: 3 additions & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ hydra:
job:
chdir: True # change to output folder


seed: 0

# Dataset
files:
dataset: data/DiffuserCam # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
Expand Down
13 changes: 10 additions & 3 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ def prep_trainable_mask(config, psf, grayscale=False):
@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM")
def train_unrolled(config):

# set seed
seed = config.seed
if seed is not None:
torch.manual_seed(seed)
np.random.seed(seed)
generator = torch.Generator().manual_seed(seed)

save = config.save
if save:
save = os.getcwd()
Expand Down Expand Up @@ -306,7 +313,9 @@ def train_unrolled(config):
# train-test split
train_size = int((1 - config.files.test_size) * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])
train_set, test_set = torch.utils.data.random_split(
dataset, [train_size, test_size], generator=generator
)
if config.files.n_files is not None:
train_set = Subset(train_set, np.arange(config.files.n_files))
test_set = Subset(test_set, np.arange(config.files.n_files))
Expand Down Expand Up @@ -334,8 +343,6 @@ def train_unrolled(config):
log.info(f"Train test size : {len(train_set)}")
log.info(f"Test test size : {len(test_set)}")

raise ValueError

start_time = time.time()

# Load pre process model
Expand Down

0 comments on commit cc33e99

Please sign in to comment.