Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 26, 2023
2 parents 44c92f6 + c14e8cb commit a0653be
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 1 addition & 3 deletions cca_zoo/deep/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ def __init__(
nn.Dropout(p=dropout), nn.Linear(linear_input_size, latent_dimensions)
)

def _build_conv_layers(
self, channels, kernel_sizes, strides, paddings, activation
):
def _build_conv_layers(self, channels, kernel_sizes, strides, paddings, activation):
layers = []
current_channels = 1
for idx in range(len(channels)):
Expand Down
8 changes: 6 additions & 2 deletions cca_zoo/linear/_gradient/_ey.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ def validation_step(self, batch, batch_idx):
return loss["objective"]

def get_dataset(self, views: Iterable[np.ndarray], validation_views=None):
dataset = DoubleNumpyDataset(views, batch_size=self.batch_size, random_state=self.random_state)
dataset = DoubleNumpyDataset(
views, batch_size=self.batch_size, random_state=self.random_state
)
if validation_views is not None:
val_dataset = DoubleNumpyDataset(validation_views, self.batch_size, self.random_state)
val_dataset = DoubleNumpyDataset(
validation_views, self.batch_size, self.random_state
)
else:
val_dataset = None
return dataset, val_dataset
Expand Down

0 comments on commit a0653be

Please sign in to comment.