Skip to content

Commit

Permalink
Big update
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 25, 2023
1 parent 2c7c9e2 commit 9703684
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 0 additions & 2 deletions cca_zoo/visualisation/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,12 @@ def plot(self):
sns.heatmap(
self.train_covariances,
annot=True,
cmap="coolwarm",
ax=axs[0],
)
if self.test_covariances is not None:
sns.heatmap(
self.test_covariances,
annot=True,
cmap="coolwarm",
ax=axs[1],
)
axs[0].set_title("Train Covariances")
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/visualisation/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _create_plot(self, x, y, hue, alpha=None, palette=None):
ax,
)

def plot(self, title=None):
def plot(self, title=""):
dimensions = self.scores[0].shape[1]
self.figures_ = []

Expand Down Expand Up @@ -226,7 +226,7 @@ def _create_plot(self, x, y, hue=None, palette=None):


class SeparateScoreScatterDisplay(ScoreScatterDisplay):
def plot(self, title=None):
def plot(self, title=""):
dimensions = self.scores[0].shape[1]
self.train_figures_ = []
self.test_figures_ = []
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/visualisation/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def plot(self, **kwargs):
self.weights_cov = [w.T @ w for w in self.weights]
# loop through each view and have a heatmap of the covariance of the weights_
for i, view_weights_cov in enumerate(self.weights_cov):
sns.heatmap(view_weights_cov, ax=axs[i], **self.kwargs)
sns.heatmap(view_weights_cov, ax=axs[i],annot=True, **self.kwargs)
axs[i].set_title(self.view_labels[i])
plt.tight_layout()
self.figure_ = fig
Expand Down

0 comments on commit 9703684

Please sign in to comment.