Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft] Add Weight Visualization #1492

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add comments
Stanislav Ponkrashov committed Feb 8, 2023
commit 01c4c47f22b893044662495186364625cf6aeaa4
73 changes: 59 additions & 14 deletions media/CircleGraph/view-sidebar.js
Original file line number Diff line number Diff line change
@@ -625,16 +625,18 @@ sidebar.ParameterView = class {
}
};

function getTensorShape(array) {
// returns 'tensor' shape as an array
function getTensorShape(tensor) {
let shape = [];
let curArray = array;
while (Array.isArray(curArray)) {
shape.push(curArray.length);
curArray = curArray[0];
let currentTensor = tensor;
while (Array.isArray(currentTensor)) {
shape.push(currentTensor.length);
currentTensor = currentTensor[0];
}
return shape;
}

// returns tensor[index[0]][index[1]]...
function getTensorValue(tensor, index) {
let value = tensor;
for (const i of index) {
@@ -643,11 +645,15 @@ function getTensorValue(tensor, index) {
return value;
}

// returns 2D subarray of a 'tensor' as normalized data
// 'axis1' of a tensor corresponds to data y
// 'axis2' of a tensor corresponds to data x
// 'values' fix all other axes values for the tensor
function normalizeTensor(tensor, axis1, axis2, values) {
let shape = getTensorShape(tensor);
let height = shape[axis1];
let width = shape[axis2];
let imageData = [];
let normalizedData = [];

let index = values.slice();

@@ -664,20 +670,29 @@ function normalizeTensor(tensor, axis1, axis2, values) {
}
}

// normalize values and create imageData
// normalize values and create normalizedData
for (let i = 0; i < height; i++) {
let row = [];
for (let j = 0; j < width; j++) {
if (minValue === maxValue) {
row.push(0.5);
continue;
}
index[axis1] = i;
index[axis2] = j;
let value = getTensorValue(tensor, index);
row.push((value - minValue) / (maxValue - minValue));
}
imageData.push(row);
normalizedData.push(row);
}
return imageData;
return normalizedData;
}

// creates a canvas in the document and renders 'tensor' values heatmap in it
// 'axis1' of a tensor corresponds to heatmap y
// 'axis2' of a tensor corresponds to heatmap x
// 'values' fix all other axes values for the tensor
// values[axis1] and values[axis2] are ignored
function tensorToImage(tensor, axis1, axis2, values, document) {
const scale = 12;
let imageData = normalizeTensor(tensor, axis1, axis2, values);
@@ -689,6 +704,9 @@ function tensorToImage(tensor, axis1, axis2, values, document) {
let ctx = canvas.getContext('2d');
ctx.imageSmoothingEnabled = false;
let imageDataArray = new Uint8ClampedArray(height * width * 4);

// add a function to extract a value from client coordinates
// (for tooltip)
canvas.getValue = (clientX, clientY) => {
const rect = canvas.getBoundingClientRect();
const x = Math.floor(Math.min(Math.max(clientX - rect.left, 0), canvas.width - 0.1) / scale);
@@ -698,6 +716,11 @@ function tensorToImage(tensor, axis1, axis2, values, document) {
index[axis2] = x;
return getTensorValue(tensor, index);
};

// set image data based on normalized values of tensor
// a normalized value of 0.0 corresponds to blue
// a normalized value of 0.5 corresponds to green
// a normalized value of 1.0 corresponds to red
for (let i = 0; i < height; i++) {
for (let j = 0; j < width; j++) {
let value = imageData[i][j];
@@ -717,13 +740,14 @@ function tensorToImage(tensor, axis1, axis2, values, document) {
let imageDataImage = new ImageData(imageDataArray, width, height);
ctx.putImageData(imageDataImage, 0, 0);

// scale
// scale the canvas
ctx.scale(scale, scale);
ctx.drawImage(canvas, 0, 0);

return canvas;
}

// Creates UI that visualizes any 2D subarray of a tensor as a heatmap.
sidebar.VisualTensorView = class {
constructor(host, tensor) {
this._host = host;
@@ -736,14 +760,25 @@ sidebar.VisualTensorView = class {
}
this._element.className = "sidebar-view-item-value-line-border";

// _axes[0] - the axis that corresponds to heatmap x
// _axes[1] - the axis that corresponds to heatmap y
this._axes = [this._tensorShape.length - 2, this._tensorShape.length - 1];

// values of all other axes
this._values = Array(this._tensorShape.length);
this._values.fill(0);

this._checkboxes = [];
this._valueTexts = [];
this._infoTexts = [];
// if tensor has more than 2 axes, we need a UI to select 2 axes for heatmap
// and values for all other axes
if (this._tensorShape.length > 2) {
this._checkboxes = [];
this._valueTexts = [];
this._infoTexts = [];
// for each axis create a checkbox and a value selector
for (let i = 0; i < this._tensorShape.length; ++i) {
// create a label that shows maximum axis value
// and "(x)" or "(y)" respectively if the axis is selected
// for a heatmap
let infoText = this._host.document.createElement("div");
infoText.setAttribute('style', 'float: left;');
infoText.setText = () => {
@@ -757,6 +792,7 @@ sidebar.VisualTensorView = class {
};
this._infoTexts.push(infoText);

// create a checkbox to select the axis for a heatmap
let checkbox = this._host.document.createElement("input");
checkbox.setAttribute('style', 'float: left;');
checkbox.type = "checkbox";
@@ -776,6 +812,7 @@ sidebar.VisualTensorView = class {
});
this._checkboxes.push(checkbox);

// create a number input to set axis value
let valueText = this._host.document.createElement("input");
valueText.className = "sidebar-view-item-value-number";
valueText.type = "number";
@@ -818,24 +855,28 @@ sidebar.VisualTensorView = class {
});
this._valueTexts.push(valueText);

// create buttons to decrease/increase axis value by 1
let leftButton = this._host.document.createElement("div");
leftButton.className = "sidebar-view-item-value-expander";
leftButton.setAttribute('style', 'float: left; padding: 1px 4px 0px 4px;');
// left arrow symbol
leftButton.innerHTML = "&#x140A";
leftButton.addEventListener("click", valueText.decrease);

let rightButton = this._host.document.createElement("div");
rightButton.className = "sidebar-view-item-value-expander";
rightButton.setAttribute('style', 'float: left; padding: 1px 4px 0px 4px;');
// right arrow symbol
rightButton.innerHTML = "&#x1405";
rightButton.addEventListener("click", valueText.increase);

// append the UI elements to root
this._element.appendChild(checkbox);
this._element.appendChild(leftButton);
this._element.appendChild(valueText);
this._element.appendChild(rightButton);
this._element.appendChild(infoText);

// new line
this._element.appendChild(this._host.document.createElement("br"));
this._element.appendChild(this._host.document.createElement("br"));
}
@@ -846,13 +887,16 @@ sidebar.VisualTensorView = class {
this.updateUI();
}

// update selected axes' labels with "(x)" and "(y)" respectively
// and disable value selectors for selected axes
updateUI() {
for (let i = 0; i < this._valueTexts.length; ++i) {
this._infoTexts[i].setText();
this._valueTexts[i].disabled = this._axes.includes(i);
}
}

// render the image and append the canvas to root
updateImage() {
if (this._imageContainer) {
this._element.removeChild(this._imageContainer);
@@ -864,6 +908,7 @@ sidebar.VisualTensorView = class {
try {
this._imageContainer = this._host.document.createElement("div");
this._image = tensorToImage(this._tensor, this._axes[1], this._axes[0], this._values, this._host.document);
// add event listener to update tooltip on mouse move
this._image.addEventListener("mousemove", (e) => {
this._imageContainer.title = this._image.getValue(e.clientX, e.clientY);
}, false);