Skip to content

Commit

Permalink
Batch-size minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
YinuoJin committed Nov 19, 2023
1 parent af6aab9 commit 9710c90
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
6 changes: 3 additions & 3 deletions starfysh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def init_weights(module):
def run_starfysh(
visium_args,
n_repeats=3,
lr=1e-3,
lr=1e-4,
epochs=100,
batch_size=32,
alpha_mul=50,
Expand Down Expand Up @@ -416,7 +416,7 @@ def run_starfysh(
train_func = train

trainset = dl_func(adata=adata, args=visium_args)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

# Running Starfysh with multiple starts
LOGGER.info('Running Starfysh with {} restarts, choose the model with best parameters...'.format(n_repeats))
Expand All @@ -440,7 +440,7 @@ def run_starfysh(
adata=adata,
gene_sig=sig_mean_norm,
win_loglib=win_loglib,
alpha_mul=alpha_mul,
alpha_mul=alpha_mul
)

model = model.to(device)
Expand Down
8 changes: 3 additions & 5 deletions starfysh/utils_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def init_weights(module):
def run_starfysh(
visium_args,
n_repeats=3,
lr=1e-3,
lr=1e-4,
epochs=100,
batch_size=32,
alpha_mul=50,
Expand Down Expand Up @@ -429,7 +429,7 @@ def run_starfysh(
train_func = train

trainset = dl_func(adata=adata, args=visium_args)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

# Running Starfysh with multiple starts
LOGGER.info('Running Starfysh with {} restarts, choose the model with best parameters...'.format(n_repeats))
Expand All @@ -447,7 +447,6 @@ def run_starfysh(
win_loglib=win_loglib,
alpha_mul=alpha_mul,
n_img_chan=1
#batch_size=batch_size,
)
# Update patched & flattened image patches
visium_args._update_img_patches(trainset)
Expand All @@ -456,8 +455,7 @@ def run_starfysh(
adata=adata,
gene_sig=sig_mean_norm,
win_loglib=win_loglib,
alpha_mul=alpha_mul,
batch_size=batch_size
alpha_mul=alpha_mul
)

model = model.to(device)
Expand Down

0 comments on commit 9710c90

Please sign in to comment.