From f79bacb8cfc54ad2ad87f3fdca0107b3b716cfb3 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 13 Feb 2024 14:25:37 -0500 Subject: [PATCH] Renaming outputs. Improving titles --- dwi_ml/testing/projects/tt_visu_argparser.py | 14 +- .../testing/projects/tt_visu_colored_sft.py | 149 ++++++++---------- dwi_ml/testing/projects/tt_visu_main.py | 16 +- dwi_ml/testing/projects/tt_visu_matrix.py | 42 +++-- dwi_ml/testing/projects/tt_visu_utils.py | 115 +++++++------- scripts_python/tests/test_all_steps_tto.py | 5 +- scripts_python/tests/test_all_steps_tts.py | 5 +- scripts_python/tests/test_all_steps_ttst.py | 5 +- 8 files changed, 180 insertions(+), 171 deletions(-) diff --git a/dwi_ml/testing/projects/tt_visu_argparser.py b/dwi_ml/testing/projects/tt_visu_argparser.py index db4e0d77..8b4c1842 100644 --- a/dwi_ml/testing/projects/tt_visu_argparser.py +++ b/dwi_ml/testing/projects/tt_visu_argparser.py @@ -28,11 +28,21 @@ 4) 'color_x_y_summary': Saves two colored tractogram: - Projection on x: This is a measure of the importance of each point on the streamline. + - x_mean_att* = mean attention strength at this point. + - x_nb_usage* = percentage of points using current point + importantly. - Projection on y: This is an indication of where we were looking at when deciding the next direction at each point on the streamline. + - y_looked_far: average index (in percentage of length) of points + used importantly. 0 = current point. 1/length = Looked at + preceding point. 1 = looked very far. + - y_nb_looked: number of points (in % of length) that were + important. + - y_max_pos: index (in % of length) of the maximal point. + 0 = current point. 1/length = Looked at preceding point. + 1 = looked very far. - The projection technique depends on the chosen rescaling options. - - Outputs: _encoder_colored_importance_*.trk _encoder_colored_looked_far*.trk 5) 'bertviz_locally': Run the bertviz without using jupyter. (Debugging purposes. Output will not show, but html stuff will print @@ -87,7 +97,7 @@ def build_argparser_transformer_visu(): gg =g.add_mutually_exclusive_group() gg.add_argument('--bertviz', action='store_true', help="See description above.") - gg.add_argument('--bertviz_locally', + gg.add_argument('--bertviz_locally', action='store_true', help="See description above.") g = p.add_argument_group("Saving options") diff --git a/dwi_ml/testing/projects/tt_visu_colored_sft.py b/dwi_ml/testing/projects/tt_visu_colored_sft.py index 8294b3b1..a1059b80 100644 --- a/dwi_ml/testing/projects/tt_visu_colored_sft.py +++ b/dwi_ml/testing/projects/tt_visu_colored_sft.py @@ -122,8 +122,11 @@ def color_sft_duplicate_lines( colored_sft = whole_sft.from_sft( whole_sft.streamlines, whole_sft, data_per_point={key: whole_sft.data_per_point[key]}) - colored_sft, cbar_fig = _color_sft_from_dpp(colored_sft, key, - title=explanation) + + # Not using a fixed vmax, vmin. Uses the bundle's data. + # Easier to view. + colored_sft, cbar_fig = _color_sft_from_dpp( + colored_sft, key, prepare_fig=True, title=explanation) filename_prefix = prefix_name + '_colored_multi_length_' + key filename_trk = filename_prefix + '.trk' @@ -135,14 +138,26 @@ def color_sft_duplicate_lines( plt.savefig(filename_cbar) -def color_sft_importance_looked_far( - sft: StatefulTractogram, lengths, prefix_name: str, +def color_sft_x_y_projections( + sft: StatefulTractogram, prefix_name: str, attentions_per_line: list, attention_names: Tuple, average_heads, average_layers, group_with_max, rescale_0_1, rescale_non_lin, rescale_z, explanation): - (options_main, options_importance, options_range_length, - explanation_part2, rescale_name) = get_visu_params_from_options( - rescale_0_1, rescale_non_lin, rescale_z, max(lengths), max(lengths)) + """ + Saves one tractogram per "projection": + + Proj on x: + - nb_usage + - mean_att + + Proj on y: + - looked_far + - max_pos + - nb_looked + """ + (options_main, options_range_length, explanation_part2, + rescale_name, thresh) = get_visu_params_from_options( + rescale_0_1, rescale_non_lin, rescale_z) explanation += '\n' + explanation_part2 @@ -168,93 +183,63 @@ def color_sft_importance_looked_far( else: head_suffix = '_h{}'.format(head) + all_nb_usage = [] + all_mean_att = [] all_looked_far = [] - all_importance = [] all_maxp = [] all_nb_looked = [] for s in range(len(sft.streamlines)): a = att_type[s][layer][head, :, :] - a, looked_far, importance, max_p, nb_lookedd = \ + a, mean_att, nb_usage, looked_far, max_p, nb_looked = \ prepare_projections_from_options( a, rescale_0_1, rescale_non_lin, rescale_z) - all_importance.append(importance[:, None]) + all_mean_att.append(mean_att[:, None]) + all_nb_usage.append(nb_usage[:, None]) all_looked_far.append(looked_far[:, None]) all_maxp.append(max_p[:, None]) - all_nb_looked.append(nb_lookedd[:, None]) + all_nb_looked.append(nb_looked[:, None]) # Save results for this attention, head, layer filename_prefix = prefix_name + attention_names[i] + \ layer_prefix + head_suffix - # 1) IMPORTANCE - sft.data_per_point['importance'] = all_importance - _color_sft_from_dpp(sft, 'importance', - options_importance['cmap'], - options_importance['vmin'], - options_importance['vmax'], - title=explanation) - filename_trk = filename_prefix + '_importance.trk' + # Mean att: not fixing the vmin, vmax. + name = 'x_mean_att' + sft.data_per_point[name] = all_mean_att + _color_sft_from_dpp(sft, name, **options_range_length) + filename_trk = filename_prefix + '_' + name + '.trk' print("Saving {} with dpp {}" .format(filename_trk, list(sft.data_per_point.keys()))) save_tractogram(sft, filename_trk, bbox_valid_check=False) - # plt.savefig(filename_cbar) - del sft.data_per_point['importance'] + del sft.data_per_point[name] del sft.data_per_point['color'] - # 2) LOOKED_FAR (mean pos where looked) - sft.data_per_point['looked_far'] = all_looked_far - _color_sft_from_dpp(sft, 'looked_far', - options_range_length['cmap'], - options_range_length['vmin'], - options_range_length['vmax'], - title=explanation) - filename_trk = filename_prefix + '_looked_far.trk' - print("Saving {} with dpp {}" - .format(filename_trk, list(sft.data_per_point.keys()))) - save_tractogram(sft, filename_trk, bbox_valid_check=False) - # plt.savefig(filename_cbar) - del sft.data_per_point['looked_far'] - del sft.data_per_point['color'] - - # 3) MAX_POS (0 = looked at current point) - sft.data_per_point['max_pos'] = all_maxp - _color_sft_from_dpp(sft, 'max_pos', - options_range_length['cmap'], - options_range_length['vmin'], - options_range_length['vmax'], - title=explanation) - filename_trk = filename_prefix + '_max_position.trk' - print("Saving {} with dpp {}" - .format(filename_trk, list(sft.data_per_point.keys()))) - save_tractogram(sft, filename_trk, bbox_valid_check=False) - # plt.savefig(filename_cbar) - del sft.data_per_point['max_pos'] - del sft.data_per_point['color'] - - # 4) NB LOOKED - sft.data_per_point['nb_looked'] = all_nb_looked - _color_sft_from_dpp(sft, 'nb_looked', - options_range_length['cmap'], - options_range_length['vmin'], - options_range_length['vmax'], - title=explanation) - filename_trk = filename_prefix + '_nb_looked.trk' - print("Saving {} with dpp {}" - .format(filename_trk, list(sft.data_per_point.keys()))) - save_tractogram(sft, filename_trk, bbox_valid_check=False) - # plt.savefig(filename_cbar) - del sft.data_per_point['nb_looked'] - del sft.data_per_point['color'] - - -def _color_sft_from_dpp(sft, key, map_name='viridis', mmin=None, mmax=None, - title=None): - cmap = get_colormap(map_name) + # Others: Range is 0 - 1 (where 1 = 100% of streamline length) + # Not saving the colorbar. + names = ('x_nb_usage', + 'y_looked_far', 'y_max_pos', 'y_nb_looked') + data = (all_nb_usage, + all_looked_far, all_maxp, all_nb_looked) + for name, vectors in zip(names, data): + sft.data_per_point[name] = vectors + _color_sft_from_dpp(sft, name, **options_range_length) + filename_trk = filename_prefix + '_' + name + '.trk' + print("Saving {} with dpp {}" + .format(filename_trk, + list(sft.data_per_point.keys()))) + save_tractogram(sft, filename_trk, bbox_valid_check=False) + del sft.data_per_point[name] + del sft.data_per_point['color'] + + +def _color_sft_from_dpp(sft, key, cmap='viridis', vmin=None, vmax=None, + prepare_fig: bool = False, title=None, **kw): + cmap = get_colormap(cmap) tmp = [np.squeeze(sft.data_per_point[key][s]) for s in range(len(sft))] data = np.hstack(tmp) - mmin = mmin or np.min(data) - mmax = mmax or np.max(data) + vmin = vmin or np.min(data) + vmax = vmax or np.max(data) data = data - np.min(data) data = data / np.max(data) color = cmap(data)[:, 0:3] * 255 @@ -262,14 +247,16 @@ def _color_sft_from_dpp(sft, key, map_name='viridis', mmin=None, mmax=None, sft.data_per_point['color']._data = color # Preparing a figure - fig = plt.figure(figsize=(9, 3)) - plt.imshow(np.array([[1, 0]]), cmap=cmap, vmin=mmin, vmax=mmax) - plt.gca().set_visible(False) - cax = plt.axes([0.1, 0.1, 0.6, 0.85]) - plt.colorbar(orientation="horizontal", cax=cax, aspect=0.01) - if title is not None: - plt.title("Colorbar for key: {}\n".format(key) + title) - else: - plt.title("Colorbar for key: {}".format(key)) + fig = None + if prepare_fig: + fig = plt.figure(figsize=(9, 3)) + plt.imshow(np.array([[1, 0]]), cmap=cmap, vmin=vmin, vmax=vmax) + plt.gca().set_visible(False) + cax = plt.axes([0.1, 0.1, 0.6, 0.85]) + plt.colorbar(orientation="horizontal", cax=cax, aspect=0.01) + if title is not None: + plt.title("Colorbar for key: {}\n".format(key) + title) + else: + plt.title("Colorbar for key: {}".format(key)) return sft, fig diff --git a/dwi_ml/testing/projects/tt_visu_main.py b/dwi_ml/testing/projects/tt_visu_main.py index c6619b42..c01a5484 100644 --- a/dwi_ml/testing/projects/tt_visu_main.py +++ b/dwi_ml/testing/projects/tt_visu_main.py @@ -24,7 +24,7 @@ encoder_decoder_show_head_view, encoder_decoder_show_model_view, encoder_show_model_view, encoder_show_head_view) from dwi_ml.testing.projects.tt_visu_colored_sft import ( - color_sft_duplicate_lines, color_sft_importance_looked_far) + color_sft_duplicate_lines, color_sft_x_y_projections) from dwi_ml.testing.projects.tt_visu_matrix import show_model_view_as_imshow from dwi_ml.testing.projects.tt_visu_utils import ( prepare_encoder_tokens, prepare_decoder_tokens, @@ -201,13 +201,13 @@ def _visu_encoder_decoder( lengths, args.rescale_0_1, args.rescale_z, args.rescale_non_lin) if has_decoder: - attention_names = ('encoder',) - else: attention_names = ('encoder', 'decoder', 'cross') + else: + attention_names = ('encoder',) if args.color_multi_length: print( - "\n\n-------------- Preparing the colors for each length of " + "\n-------------- Preparing the colors for each length of " "each streamline --------------") color_sft_duplicate_lines(sft, lengths, prefix_name, weights, attention_names, average_heads, @@ -216,10 +216,10 @@ def _visu_encoder_decoder( if args.color_x_y_summary: print( - "\n\n-------------- Preparing the colors summary (importance, " - "where looked) for each streamline --------------") - color_sft_importance_looked_far( - sft, lengths, prefix_name, weights, attention_names, + "\n-------------- Preparing the colors summary (nb_usage, " + "where looked, etc) for each streamline --------------") + color_sft_x_y_projections( + sft, prefix_name, weights, attention_names, average_heads, average_layers, args.group_with_max, args.rescale_0_1, args.rescale_non_lin, args.rescale_z, explanation) diff --git a/dwi_ml/testing/projects/tt_visu_matrix.py b/dwi_ml/testing/projects/tt_visu_matrix.py index 288c0f2d..cdd5ea8a 100644 --- a/dwi_ml/testing/projects/tt_visu_matrix.py +++ b/dwi_ml/testing/projects/tt_visu_matrix.py @@ -17,9 +17,9 @@ def show_model_view_as_imshow( size_x = len(tokens_x) size_y = len(tokens_y) - (options_main, options_importance, options_range_length, - explanation, rescale_name) = get_visu_params_from_options( - rescale_0_1, rescale_non_lin, rescale_z, size_x, size_y) + (options_main, options_range_length, explanation, rescale_name, + thresh) = get_visu_params_from_options( + rescale_0_1, rescale_non_lin, rescale_z) for i in range(nb_layers): att = attention_one_line[i] @@ -30,11 +30,12 @@ def show_model_view_as_imshow( axs = [axs] for h in range(nb_heads): - a, where_looked, importance, maxp, nb_looked = \ + a, mean_att, importance, looked_far, max_pos, nb_looked = \ prepare_projections_from_options( att[h, :, :], rescale_0_1, rescale_non_lin, rescale_z) divider = make_axes_locatable(axs[h]) + ax_mean_att = divider.append_axes("bottom", size=0.2, pad=0) ax_importance = divider.append_axes("bottom", size=0.2, pad=0) ax_lookedfar = divider.append_axes("right", size=0.2, pad=0) ax_max = divider.append_axes("right", size=0.2, pad=0) @@ -42,19 +43,21 @@ def show_model_view_as_imshow( ax_cbar_main = divider.append_axes("right", size=0.3, pad=0.3) ax_cbar_length = divider.append_axes("right", size=0.3, pad=0.55) + # Plot the main image + im_main = axs[h].imshow(a, **options_main) + # Bottom and right images + _ = ax_mean_att.imshow(mean_att[None, :], + **options_main, aspect='auto') im_b = ax_importance.imshow(importance[None, :], - **options_importance, aspect='auto') - _ = ax_lookedfar.imshow(where_looked[:, None], + **options_range_length, aspect='auto') + _ = ax_lookedfar.imshow(looked_far[:, None], **options_range_length, aspect='auto') - _ = ax_max.imshow(maxp[:, None], + _ = ax_max.imshow(max_pos[:, None], **options_range_length, aspect='auto') _ = ax_nb_looked.imshow(nb_looked[:, None], **options_range_length, aspect='auto') - # Plot the main image - im_main = axs[h].imshow(a, **options_main) - # Set the titles (see also suptitle below) if average_heads: if group_with_max: @@ -65,30 +68,39 @@ def show_model_view_as_imshow( .format(rescale_name)) else: axs[h].set_title("Head {}".format(h)) + + # Titles proj X + ax_mean_att.set_ylabel("Mean", rotation=0, labelpad=25) + ax_importance.set_ylabel("Importance.", rotation=0, labelpad=25) + + # Titles proj Y ax_lookedfar.set_title("Looked far", rotation=45, loc='left') ax_max.set_title("Max pos", rotation=45, loc='left') ax_nb_looked.set_title("Nb looked", rotation=45, loc='left') - ax_importance.set_ylabel("Importance.", rotation=0, labelpad=25) # ("Importance" is a bit too close to last tick. Tried to use # loc='bottom' but then ignores labelpad). - # Set the ticks with tokens. + # Main image: set the ticks with tokens. axs[h].set_xticks(np.arange(size_x), fontsize=10) axs[h].set_yticks(np.arange(size_y), fontsize=10) - axs[h].tick_params(axis='x', pad=20) axs[h].set_xticklabels(tokens_x, rotation=-90) axs[h].set_yticklabels(tokens_y) + # Move x ticks under the projections + axs[h].tick_params(axis='x', pad=40) + # Other plots: Hide ticks. - for ax in [ax_importance, ax_lookedfar, ax_max, ax_nb_looked]: + for ax in [ax_mean_att, ax_importance, + ax_lookedfar, ax_max, ax_nb_looked]: plt.setp(ax.get_xticklabels(), visible=False) plt.setp(ax.get_yticklabels(), visible=False) # Set the colorbars, with titles. + # ToDo. Colorbar mean_att. fig.colorbar(im_main, cax=ax_cbar_main) ax_cbar_main.set_ylabel('Main figure', rotation=90, labelpad=-55) fig.colorbar(im_b, cax=ax_cbar_length) - ax_cbar_length.set_ylabel('x / y projections: [0, length]', + ax_cbar_length.set_ylabel('x / y projections: %% of length', rotation=90, labelpad=-55) if average_layers: diff --git a/dwi_ml/testing/projects/tt_visu_utils.py b/dwi_ml/testing/projects/tt_visu_utils.py index cf49ff52..18b10134 100644 --- a/dwi_ml/testing/projects/tt_visu_utils.py +++ b/dwi_ml/testing/projects/tt_visu_utils.py @@ -10,6 +10,14 @@ from scilpy.io.fetcher import get_home as get_scilpy_folder +THRESH_IMPORTANT = { + 'rescale_0_1': 0.75, # No idea... 1 = the most important. + 'rescale_non_lin': 0.5, # 0.5 = Value when all equal. + 'rescale_z': 1.96, # The 95 percent confidence level with p<0.05 + 'None': 0.5 # NOT SIGNIFICATIVE AT ALL. +} + + def reshape_unpad_rescale_attention( attention_per_layer, average_heads: bool, average_layers, group_with_max, lengths, rescale_0_1, rescale_z, rescale_non_lin): @@ -198,98 +206,87 @@ def prepare_encoder_tokens(this_seq_len, step_size, add_eos: bool): return encoder_tokens -def get_visu_params_from_options( - rescale_0_1, rescale_non_lin, rescale_z, size_x, size_y): +def get_visu_params_from_options(rescale_0_1, rescale_non_lin, rescale_z): """ Defines options for prefix names, colormaps, vmin, vmax, explanation text, etc. """ - vmin_main, vmax_main, cmap_main = (0, 1, 'viridis') - vmin_importance, vmax_importance, cmap_importance = (0, 1, 'viridis') - vmin_pos, vmax_pos, cmap_pos = (0, size_x, 'viridis') - if rescale_0_1 or rescale_non_lin: - if rescale_0_1: - explanation = ( - 'Importance: Number of times that this point was ' - 'very important (> 0.75).\n' - "Looked far: Where the important points (>0.75) to decide " - "next direction are situated." - "0 = current point. Max = very far behind.") - rescale_name = 'rescale_0_1' - else: - explanation = ( - "Importance: Number of times that this point was more " - "important than the average >0.5.\n" - "Looked far: Where the important points (>0.5) to decide next " - "direction are situated. " - "0 = current point. Max = very far behind.") - rescale_name = 'rescale_non_lin' - cmap_main = 'coolwarm' # See also turbo, rainbow - vmax_importance = size_y - cmap_importance = 'plasma' # See also rainbow, inferno - cmap_pos = 'plasma' + vmin_main, vmax_main, cmap_main = (0, 1, 'turbo') + vmin_pos, vmax_pos, cmap_pos = (0, 1, 'CMRmap') + if rescale_0_1: + rescale_name = 'rescale_0_1' + elif rescale_non_lin: + rescale_name = 'rescale_non_lin' + # cmap_main = 'coolwarm' elif rescale_z: - explanation = "Bottom row: Average." rescale_name = 'rescale_z' - vmin_main = None - vmax_main = None - vmax_importance = None - vmin_importance = None + # Range: We could limit it to help view better. Ex: ±3 std. + vmin_main = -3 + vmax_main = 3 else: - explanation = 'Bottom row: Average.' - rescale_name = '' + rescale_name = 'None' + + thresh = THRESH_IMPORTANT[rescale_name] + explanation = ( + 'Importance: Number of times that this point was very important ' + '(>{:.2f}).\n' + "Looked far: Mean index of the important points (>{:.2f}) to decide " + "the next direction. 0 = current point. 100%% = very far behind.\n" + "Max_pos: Index of the point of maximal attention.\n" + "Nb_looked: Number of points of important attention." + .format(thresh, thresh)) options_main = {'interpolation': 'None', 'cmap': cmap_main, 'vmin': vmin_main, 'vmax': vmax_main} - options_importance = {'interpolation': 'None', - 'cmap': cmap_importance, - 'vmin': vmin_importance, - 'vmax': vmax_importance} - options_position = {'interpolation': 'None', 'cmap': cmap_pos, 'vmin': vmin_pos, 'vmax': vmax_pos} - return (options_main, options_importance, options_position, - explanation, rescale_name) + return options_main, options_position, explanation, rescale_name, thresh def prepare_projections_from_options(a, rescale_0_1, rescale_non_lin, rescale_z): a = np.squeeze(a) a = np.ma.masked_where(a == 0, a) - if rescale_0_1 or rescale_non_lin: - if rescale_0_1: - thresh = 0.75 - else: - thresh = 0.5 + if rescale_0_1: + rescale_name = 'rescale_0_1' + elif rescale_non_lin: + rescale_name = 'rescale_non_lin' + elif rescale_z: + rescale_name = 'rescale_z' + else: + rescale_name = 'None' + thresh = THRESH_IMPORTANT[rescale_name] - # Importance = nb of points > thresh as x projection - importance = np.sum(a > thresh, axis=0) + length = float(a.shape[1]) + flipped_range = np.flip(np.arange(1, a.shape[1] + 1)) - # Nb looked = nb of points > thresh as y projection - nb_looked = np.sum(a > thresh, axis=1) + # Mean = masked mean. + mean_att = np.sum(a, axis=0) / flipped_range - # Looked far = mean index of points where > thresh. - indexes = np.arange(1, a.shape[1] + 1) - indexes = np.abs(indexes[None, :] - indexes[:, None]) - indexes = np.ma.masked_where(~(a > thresh), indexes) - looked_far = np.mean((a > thresh) * indexes, axis=1) - # looked_far /= np.flip(np.arange(1, a.shape[1] + 1)) + # Importance = nb of points > thresh as x projection + importance = np.sum(a > thresh, axis=0) / length - else: - importance = np.mean(a, axis=0) - raise NotImplemented("Where looked not defined.") + # Looked far = mean index of points where > thresh. + indexes = np.arange(1, a.shape[1] + 1) + indexes = np.abs(indexes[None, :] - indexes[:, None]) + indexes = np.ma.masked_where(~(a > thresh), indexes) + looked_far = np.mean((a > thresh) * indexes, axis=1) / length # Position of maximal point max_pos = np.argmax(a, axis=1) + 1 max_pos = np.arange(1, a.shape[1] + 1) - max_pos + max_pos = max_pos / length + + # Nb looked = nb of points > thresh as y projection + nb_looked = np.sum(a > thresh, axis=1) / length - return a, looked_far, importance, max_pos, nb_looked + return a, mean_att, importance, looked_far, max_pos, nb_looked def get_config_filename(): diff --git a/scripts_python/tests/test_all_steps_tto.py b/scripts_python/tests/test_all_steps_tto.py index e249bed9..8b78d03c 100644 --- a/scripts_python/tests/test_all_steps_tto.py +++ b/scripts_python/tests/test_all_steps_tto.py @@ -130,8 +130,9 @@ def test_execution(script_runner, experiments_path): ret = script_runner.run( 'tt_visualize_weights.py', whole_experiment_path, hdf5_file, subj_id, input_group, in_sft, '--out_prefix', prefix, - '--visu_type', 'as_matrices', 'colored_sft', 'bertviz_locally', + '--as_matrices', '--color_multi_length', '--color_x_y_summary', + '--bertviz_locally', '--subset', 'training', '--logging', 'INFO', - '--resample_plots', '15', '--rescale') + '--resample_plots', '15', '--rescale_non_lin') assert ret.success diff --git a/scripts_python/tests/test_all_steps_tts.py b/scripts_python/tests/test_all_steps_tts.py index d1d1d8fb..c44b972d 100644 --- a/scripts_python/tests/test_all_steps_tts.py +++ b/scripts_python/tests/test_all_steps_tts.py @@ -98,9 +98,10 @@ def test_execution(script_runner, experiments_path): ret = script_runner.run( 'tt_visualize_weights.py', whole_experiment_path, hdf5_file, subj_id, input_group, in_sft, '--out_prefix', prefix, - '--visu_type', 'as_matrices', 'colored_sft', 'bertviz_locally', + '--as_matrices', '--color_multi_length', '--color_x_y_summary', + '--bertviz_locally', '--subset', 'training', '--logging', 'INFO', - '--resample_plots', '15', '--rescale') + '--resample_plots', '15', '--rescale_0') assert ret.success diff --git a/scripts_python/tests/test_all_steps_ttst.py b/scripts_python/tests/test_all_steps_ttst.py index 9827c953..252dd897 100644 --- a/scripts_python/tests/test_all_steps_ttst.py +++ b/scripts_python/tests/test_all_steps_ttst.py @@ -100,7 +100,8 @@ def test_execution(script_runner, experiments_path): ret = script_runner.run( 'tt_visualize_weights.py', whole_experiment_path, hdf5_file, subj_id, input_group, in_sft, '--out_prefix', prefix, - '--visu_type', 'as_matrices', 'colored_sft', 'bertviz_locally', + '--as_matrices', '--color_multi_length', '--color_x_y_summary', + '--bertviz_locally', '--subset', 'training', '--logging', 'INFO', - '--resample_plots', '15', '--rescale') + '--resample_plots', '15', '--rescale_z') assert ret.success