diff --git a/dlnlputils/recipe_utils.py b/dlnlputils/recipe_utils.py index 21f8c28..2aa7ad5 100644 --- a/dlnlputils/recipe_utils.py +++ b/dlnlputils/recipe_utils.py @@ -135,7 +135,7 @@ def plot_recipe_statistics(correct_recipes, total_recipes=None): def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues): - cm = confusion_matrix(y_true, y_pred, classes) + cm = confusion_matrix(y_true, y_pred, labels = classes) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] diff --git a/task5_text_transformer.ipynb b/task5_text_transformer.ipynb index b8628c6..37576c0 100644 --- a/task5_text_transformer.ipynb +++ b/task5_text_transformer.ipynb @@ -812,7 +812,7 @@ " is_training,\n", " weights_dropout):\n", " \"\"\"\n", - " queries - BatchSize x ValuesLen x HeadN x KeySize\n", + " queries - BatchSize x QueriesLen x HeadN x KeySize\n", " keys - BatchSize x KeysLen x HeadN x KeySize\n", " values - BatchSize x KeysLen x HeadN x ValueSize\n", " keys_padding_mask - BatchSize x KeysLen\n", @@ -821,12 +821,12 @@ " weights_dropout - float\n", " \n", " result - tuple of two:\n", - " - BatchSize x ValuesLen x HeadN x ValueSize - resulting features\n", - " - BatchSize x ValuesLen x KeysLen x HeadN - attention map\n", + " - BatchSize x QueriesLen x HeadN x ValueSize - resulting features\n", + " - BatchSize x QueriesLen x KeysLen x HeadN - attention map\n", " \"\"\"\n", "\n", - " # BatchSize x ValuesLen x KeysLen x HeadN\n", - " relevances = torch.einsum('bvhs,bkhs->bvkh', (queries, keys))\n", + " # BatchSize x QueriesLen x KeysLen x HeadN\n", + " relevances = torch.einsum('bqhs,bkhs->bqkh', (queries, keys))\n", " \n", " # замаскировать элементы, выходящие за длины последовательностей ключей\n", " padding_mask_expanded = keys_padding_mask[:, None, :, None].expand_as(relevances)\n", @@ -838,15 +838,15 @@ " normed_rels = F.softmax(relevances, dim=2) \n", " normed_rels = F.dropout(normed_rels, weights_dropout, is_training)\n", " \n", - " # BatchSize x ValuesLen x KeysLen x HeadN x 1\n", + " # BatchSize x QueriesLen x KeysLen x HeadN x 1\n", " normed_rels_expanded = normed_rels.unsqueeze(-1)\n", " \n", " # BatchSize x 1 x KeysLen x HeadN x ValueSize\n", " values_expanded = values.unsqueeze(1)\n", " \n", - " # BatchSize x ValuesLen x KeysLen x HeadN x ValueSize\n", + " # BatchSize x QueriesLen x KeysLen x HeadN x ValueSize\n", " weighted_values = normed_rels_expanded * values_expanded\n", - " result = weighted_values.sum(2) # BatchSize x ValuesLen x HeadN x ValueSize\n", + " result = weighted_values.sum(2) # BatchSize x QueriesLen x HeadN x ValueSize\n", " \n", " return result, normed_rels" ] diff --git a/task6_recipe_ner.ipynb b/task6_recipe_ner.ipynb index f40dc91..795aaca 100644 --- a/task6_recipe_ner.ipynb +++ b/task6_recipe_ner.ipynb @@ -310,7 +310,7 @@ " \n", " \n", " if i % 500 == 0:\n", - " liveplot.update({'negative log likelihood loss': loss})\n", + " liveplot.update({'negative log likelihood loss': loss.detach()})\n", " liveplot.draw()\n", " \n", " \n",