Skip to content

Commit

Permalink
Renaming outputs. Improving titles
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Apr 9, 2024
1 parent 37ec5c4 commit f79bacb
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 171 deletions.
14 changes: 12 additions & 2 deletions dwi_ml/testing/projects/tt_visu_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,21 @@
4) 'color_x_y_summary': Saves two colored tractogram:
- Projection on x: This is a measure of the importance of each
point on the streamline.
- x_mean_att* = mean attention strength at this point.
- x_nb_usage* = percentage of points using current point
importantly.
- Projection on y: This is an indication of where we were looking
at when deciding the next direction at each point on the
streamline.
- y_looked_far: average index (in percentage of length) of points
used importantly. 0 = current point. 1/length = Looked at
preceding point. 1 = looked very far.
- y_nb_looked: number of points (in % of length) that were
important.
- y_max_pos: index (in % of length) of the maximal point.
0 = current point. 1/length = Looked at preceding point.
1 = looked very far.
- The projection technique depends on the chosen rescaling options.
- Outputs: _encoder_colored_importance_*.trk
_encoder_colored_looked_far*.trk
5) 'bertviz_locally': Run the bertviz without using jupyter.
(Debugging purposes. Output will not show, but html stuff will print
Expand Down Expand Up @@ -87,7 +97,7 @@ def build_argparser_transformer_visu():
gg =g.add_mutually_exclusive_group()
gg.add_argument('--bertviz', action='store_true',
help="See description above.")
gg.add_argument('--bertviz_locally',
gg.add_argument('--bertviz_locally', action='store_true',
help="See description above.")

g = p.add_argument_group("Saving options")
Expand Down
149 changes: 68 additions & 81 deletions dwi_ml/testing/projects/tt_visu_colored_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@ def color_sft_duplicate_lines(
colored_sft = whole_sft.from_sft(
whole_sft.streamlines, whole_sft,
data_per_point={key: whole_sft.data_per_point[key]})
colored_sft, cbar_fig = _color_sft_from_dpp(colored_sft, key,
title=explanation)

# Not using a fixed vmax, vmin. Uses the bundle's data.
# Easier to view.
colored_sft, cbar_fig = _color_sft_from_dpp(
colored_sft, key, prepare_fig=True, title=explanation)

filename_prefix = prefix_name + '_colored_multi_length_' + key
filename_trk = filename_prefix + '.trk'
Expand All @@ -135,14 +138,26 @@ def color_sft_duplicate_lines(
plt.savefig(filename_cbar)


def color_sft_importance_looked_far(
sft: StatefulTractogram, lengths, prefix_name: str,
def color_sft_x_y_projections(
sft: StatefulTractogram, prefix_name: str,
attentions_per_line: list, attention_names: Tuple,
average_heads, average_layers, group_with_max,
rescale_0_1, rescale_non_lin, rescale_z, explanation):
(options_main, options_importance, options_range_length,
explanation_part2, rescale_name) = get_visu_params_from_options(
rescale_0_1, rescale_non_lin, rescale_z, max(lengths), max(lengths))
"""
Saves one tractogram per "projection":
Proj on x:
- nb_usage
- mean_att
Proj on y:
- looked_far
- max_pos
- nb_looked
"""
(options_main, options_range_length, explanation_part2,
rescale_name, thresh) = get_visu_params_from_options(
rescale_0_1, rescale_non_lin, rescale_z)

explanation += '\n' + explanation_part2

Expand All @@ -168,108 +183,80 @@ def color_sft_importance_looked_far(
else:
head_suffix = '_h{}'.format(head)

all_nb_usage = []
all_mean_att = []
all_looked_far = []
all_importance = []
all_maxp = []
all_nb_looked = []
for s in range(len(sft.streamlines)):
a = att_type[s][layer][head, :, :]
a, looked_far, importance, max_p, nb_lookedd = \
a, mean_att, nb_usage, looked_far, max_p, nb_looked = \
prepare_projections_from_options(
a, rescale_0_1, rescale_non_lin, rescale_z)
all_importance.append(importance[:, None])
all_mean_att.append(mean_att[:, None])
all_nb_usage.append(nb_usage[:, None])
all_looked_far.append(looked_far[:, None])
all_maxp.append(max_p[:, None])
all_nb_looked.append(nb_lookedd[:, None])
all_nb_looked.append(nb_looked[:, None])

# Save results for this attention, head, layer
filename_prefix = prefix_name + attention_names[i] + \
layer_prefix + head_suffix

# 1) IMPORTANCE
sft.data_per_point['importance'] = all_importance
_color_sft_from_dpp(sft, 'importance',
options_importance['cmap'],
options_importance['vmin'],
options_importance['vmax'],
title=explanation)
filename_trk = filename_prefix + '_importance.trk'
# Mean att: not fixing the vmin, vmax.
name = 'x_mean_att'
sft.data_per_point[name] = all_mean_att
_color_sft_from_dpp(sft, name, **options_range_length)
filename_trk = filename_prefix + '_' + name + '.trk'
print("Saving {} with dpp {}"
.format(filename_trk, list(sft.data_per_point.keys())))
save_tractogram(sft, filename_trk, bbox_valid_check=False)
# plt.savefig(filename_cbar)
del sft.data_per_point['importance']
del sft.data_per_point[name]
del sft.data_per_point['color']

# 2) LOOKED_FAR (mean pos where looked)
sft.data_per_point['looked_far'] = all_looked_far
_color_sft_from_dpp(sft, 'looked_far',
options_range_length['cmap'],
options_range_length['vmin'],
options_range_length['vmax'],
title=explanation)
filename_trk = filename_prefix + '_looked_far.trk'
print("Saving {} with dpp {}"
.format(filename_trk, list(sft.data_per_point.keys())))
save_tractogram(sft, filename_trk, bbox_valid_check=False)
# plt.savefig(filename_cbar)
del sft.data_per_point['looked_far']
del sft.data_per_point['color']

# 3) MAX_POS (0 = looked at current point)
sft.data_per_point['max_pos'] = all_maxp
_color_sft_from_dpp(sft, 'max_pos',
options_range_length['cmap'],
options_range_length['vmin'],
options_range_length['vmax'],
title=explanation)
filename_trk = filename_prefix + '_max_position.trk'
print("Saving {} with dpp {}"
.format(filename_trk, list(sft.data_per_point.keys())))
save_tractogram(sft, filename_trk, bbox_valid_check=False)
# plt.savefig(filename_cbar)
del sft.data_per_point['max_pos']
del sft.data_per_point['color']

# 4) NB LOOKED
sft.data_per_point['nb_looked'] = all_nb_looked
_color_sft_from_dpp(sft, 'nb_looked',
options_range_length['cmap'],
options_range_length['vmin'],
options_range_length['vmax'],
title=explanation)
filename_trk = filename_prefix + '_nb_looked.trk'
print("Saving {} with dpp {}"
.format(filename_trk, list(sft.data_per_point.keys())))
save_tractogram(sft, filename_trk, bbox_valid_check=False)
# plt.savefig(filename_cbar)
del sft.data_per_point['nb_looked']
del sft.data_per_point['color']


def _color_sft_from_dpp(sft, key, map_name='viridis', mmin=None, mmax=None,
title=None):
cmap = get_colormap(map_name)
# Others: Range is 0 - 1 (where 1 = 100% of streamline length)
# Not saving the colorbar.
names = ('x_nb_usage',
'y_looked_far', 'y_max_pos', 'y_nb_looked')
data = (all_nb_usage,
all_looked_far, all_maxp, all_nb_looked)
for name, vectors in zip(names, data):
sft.data_per_point[name] = vectors
_color_sft_from_dpp(sft, name, **options_range_length)
filename_trk = filename_prefix + '_' + name + '.trk'
print("Saving {} with dpp {}"
.format(filename_trk,
list(sft.data_per_point.keys())))
save_tractogram(sft, filename_trk, bbox_valid_check=False)
del sft.data_per_point[name]
del sft.data_per_point['color']


def _color_sft_from_dpp(sft, key, cmap='viridis', vmin=None, vmax=None,
prepare_fig: bool = False, title=None, **kw):
cmap = get_colormap(cmap)
tmp = [np.squeeze(sft.data_per_point[key][s]) for s in range(len(sft))]
data = np.hstack(tmp)

mmin = mmin or np.min(data)
mmax = mmax or np.max(data)
vmin = vmin or np.min(data)
vmax = vmax or np.max(data)
data = data - np.min(data)
data = data / np.max(data)
color = cmap(data)[:, 0:3] * 255
sft.data_per_point['color'] = sft.streamlines
sft.data_per_point['color']._data = color

# Preparing a figure
fig = plt.figure(figsize=(9, 3))
plt.imshow(np.array([[1, 0]]), cmap=cmap, vmin=mmin, vmax=mmax)
plt.gca().set_visible(False)
cax = plt.axes([0.1, 0.1, 0.6, 0.85])
plt.colorbar(orientation="horizontal", cax=cax, aspect=0.01)
if title is not None:
plt.title("Colorbar for key: {}\n".format(key) + title)
else:
plt.title("Colorbar for key: {}".format(key))
fig = None
if prepare_fig:
fig = plt.figure(figsize=(9, 3))
plt.imshow(np.array([[1, 0]]), cmap=cmap, vmin=vmin, vmax=vmax)
plt.gca().set_visible(False)
cax = plt.axes([0.1, 0.1, 0.6, 0.85])
plt.colorbar(orientation="horizontal", cax=cax, aspect=0.01)
if title is not None:
plt.title("Colorbar for key: {}\n".format(key) + title)
else:
plt.title("Colorbar for key: {}".format(key))

return sft, fig
16 changes: 8 additions & 8 deletions dwi_ml/testing/projects/tt_visu_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
encoder_decoder_show_head_view, encoder_decoder_show_model_view,
encoder_show_model_view, encoder_show_head_view)
from dwi_ml.testing.projects.tt_visu_colored_sft import (
color_sft_duplicate_lines, color_sft_importance_looked_far)
color_sft_duplicate_lines, color_sft_x_y_projections)
from dwi_ml.testing.projects.tt_visu_matrix import show_model_view_as_imshow
from dwi_ml.testing.projects.tt_visu_utils import (
prepare_encoder_tokens, prepare_decoder_tokens,
Expand Down Expand Up @@ -201,13 +201,13 @@ def _visu_encoder_decoder(
lengths, args.rescale_0_1, args.rescale_z, args.rescale_non_lin)

if has_decoder:
attention_names = ('encoder',)
else:
attention_names = ('encoder', 'decoder', 'cross')
else:
attention_names = ('encoder',)

if args.color_multi_length:
print(
"\n\n-------------- Preparing the colors for each length of "
"\n-------------- Preparing the colors for each length of "
"each streamline --------------")
color_sft_duplicate_lines(sft, lengths, prefix_name, weights,
attention_names, average_heads,
Expand All @@ -216,10 +216,10 @@ def _visu_encoder_decoder(

if args.color_x_y_summary:
print(
"\n\n-------------- Preparing the colors summary (importance, "
"where looked) for each streamline --------------")
color_sft_importance_looked_far(
sft, lengths, prefix_name, weights, attention_names,
"\n-------------- Preparing the colors summary (nb_usage, "
"where looked, etc) for each streamline --------------")
color_sft_x_y_projections(
sft, prefix_name, weights, attention_names,
average_heads, average_layers, args.group_with_max,
args.rescale_0_1, args.rescale_non_lin, args.rescale_z,
explanation)
Expand Down
42 changes: 27 additions & 15 deletions dwi_ml/testing/projects/tt_visu_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def show_model_view_as_imshow(
size_x = len(tokens_x)
size_y = len(tokens_y)

(options_main, options_importance, options_range_length,
explanation, rescale_name) = get_visu_params_from_options(
rescale_0_1, rescale_non_lin, rescale_z, size_x, size_y)
(options_main, options_range_length, explanation, rescale_name,
thresh) = get_visu_params_from_options(
rescale_0_1, rescale_non_lin, rescale_z)

for i in range(nb_layers):
att = attention_one_line[i]
Expand All @@ -30,31 +30,34 @@ def show_model_view_as_imshow(
axs = [axs]

for h in range(nb_heads):
a, where_looked, importance, maxp, nb_looked = \
a, mean_att, importance, looked_far, max_pos, nb_looked = \
prepare_projections_from_options(
att[h, :, :], rescale_0_1, rescale_non_lin, rescale_z)

divider = make_axes_locatable(axs[h])
ax_mean_att = divider.append_axes("bottom", size=0.2, pad=0)
ax_importance = divider.append_axes("bottom", size=0.2, pad=0)
ax_lookedfar = divider.append_axes("right", size=0.2, pad=0)
ax_max = divider.append_axes("right", size=0.2, pad=0)
ax_nb_looked = divider.append_axes("right", size=0.2, pad=0)
ax_cbar_main = divider.append_axes("right", size=0.3, pad=0.3)
ax_cbar_length = divider.append_axes("right", size=0.3, pad=0.55)

# Plot the main image
im_main = axs[h].imshow(a, **options_main)

# Bottom and right images
_ = ax_mean_att.imshow(mean_att[None, :],
**options_main, aspect='auto')
im_b = ax_importance.imshow(importance[None, :],
**options_importance, aspect='auto')
_ = ax_lookedfar.imshow(where_looked[:, None],
**options_range_length, aspect='auto')
_ = ax_lookedfar.imshow(looked_far[:, None],
**options_range_length, aspect='auto')
_ = ax_max.imshow(maxp[:, None],
_ = ax_max.imshow(max_pos[:, None],
**options_range_length, aspect='auto')
_ = ax_nb_looked.imshow(nb_looked[:, None],
**options_range_length, aspect='auto')

# Plot the main image
im_main = axs[h].imshow(a, **options_main)

# Set the titles (see also suptitle below)
if average_heads:
if group_with_max:
Expand All @@ -65,30 +68,39 @@ def show_model_view_as_imshow(
.format(rescale_name))
else:
axs[h].set_title("Head {}".format(h))

# Titles proj X
ax_mean_att.set_ylabel("Mean", rotation=0, labelpad=25)
ax_importance.set_ylabel("Importance.", rotation=0, labelpad=25)

# Titles proj Y
ax_lookedfar.set_title("Looked far", rotation=45, loc='left')
ax_max.set_title("Max pos", rotation=45, loc='left')
ax_nb_looked.set_title("Nb looked", rotation=45, loc='left')
ax_importance.set_ylabel("Importance.", rotation=0, labelpad=25)
# ("Importance" is a bit too close to last tick. Tried to use
# loc='bottom' but then ignores labelpad).

# Set the ticks with tokens.
# Main image: set the ticks with tokens.
axs[h].set_xticks(np.arange(size_x), fontsize=10)
axs[h].set_yticks(np.arange(size_y), fontsize=10)
axs[h].tick_params(axis='x', pad=20)
axs[h].set_xticklabels(tokens_x, rotation=-90)
axs[h].set_yticklabels(tokens_y)

# Move x ticks under the projections
axs[h].tick_params(axis='x', pad=40)

# Other plots: Hide ticks.
for ax in [ax_importance, ax_lookedfar, ax_max, ax_nb_looked]:
for ax in [ax_mean_att, ax_importance,
ax_lookedfar, ax_max, ax_nb_looked]:
plt.setp(ax.get_xticklabels(), visible=False)
plt.setp(ax.get_yticklabels(), visible=False)

# Set the colorbars, with titles.
# ToDo. Colorbar mean_att.
fig.colorbar(im_main, cax=ax_cbar_main)
ax_cbar_main.set_ylabel('Main figure', rotation=90, labelpad=-55)
fig.colorbar(im_b, cax=ax_cbar_length)
ax_cbar_length.set_ylabel('x / y projections: [0, length]',
ax_cbar_length.set_ylabel('x / y projections: %% of length',
rotation=90, labelpad=-55)

if average_layers:
Expand Down
Loading

0 comments on commit f79bacb

Please sign in to comment.