From c14a578da2ec14e961b7104c1b5eb38399df9e1a Mon Sep 17 00:00:00 2001 From: liaozy Date: Tue, 7 Mar 2023 21:16:00 +0800 Subject: [PATCH 1/7] issue --- chapter_13_chapter_computer-vision/ssd.ipynb | 274 ++++++++++++++++++- 1 file changed, 263 insertions(+), 11 deletions(-) diff --git a/chapter_13_chapter_computer-vision/ssd.ipynb b/chapter_13_chapter_computer-vision/ssd.ipynb index 80d6164..a3743b4 100644 --- a/chapter_13_chapter_computer-vision/ssd.ipynb +++ b/chapter_13_chapter_computer-vision/ssd.ipynb @@ -515,9 +515,87 @@ "# bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)" ] }, + { + "cell_type": "markdown", + "id": "5b25ac81", + "metadata": {}, + "source": [ + "# debug" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d717ec27", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "read 1000 training examples\n", + "read 100 validation examples\n" + ] + } + ], + "source": [ + "import os\n", + "import pandas as pd\n", + "import mindspore.dataset as ds\n", + "import numpy as np\n", + "\n", + "def read_data_bananas(is_train=True):\n", + " \"\"\"读取香蕉检测数据集中的图像和标签\"\"\"\n", + " data_dir = d2l.download_extract('banana-detection')\n", + " csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')\n", + " csv_data = pd.read_csv(csv_fname)\n", + " csv_data = csv_data.set_index('img_name')\n", + " images, targets = [], []\n", + " for img_name, target in csv_data.iterrows():\n", + " images.append(ds.vision.read_image(\n", + " os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))\n", + " # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),\n", + " # 其中所有图像都具有相同的香蕉类(索引为0)\n", + " targets.append(list(target))\n", + " return images, mindspore.Tensor(targets, dtype=mindspore.float32).unsqueeze(1) / 255\n", + "\n", + "\n", + "class BananasDataset():\n", + " \"\"\"一个用于加载香蕉检测数据集的自定义数据集\"\"\"\n", + " def __init__(self, is_train):\n", + " self.parent = None\n", + " self.features, self.labels = read_data_bananas(is_train)\n", + " print('read ' + str(len(self.features)) + (f' training examples' if\n", + " is_train else f' validation examples'))\n", + "\n", + " def __getitem__(self, idx):\n", + " print(idx)\n", + " return (np.array(self.features[int(idx)], dtype='float32'), self.labels[int(idx)])\n", + "\n", + " def __len__(self):\n", + " return len(self.features)\n", + "\n", + "\n", + "\n", + "def load_data_bananas(batch_size):\n", + " \"\"\"加载香蕉检测数据集\"\"\"\n", + " train_iter = ds.GeneratorDataset(source=BananasDataset(is_train=True),\n", + " column_names=['imgs', 'labels'], shuffle=True)\n", + " val_iter = ds.GeneratorDataset(source=BananasDataset(is_train=False),\n", + " column_names=['imgs', 'labels'], shuffle=False)\n", + " train_iter = train_iter.map(mindspore.dataset.vision.HWC2CHW(),input_columns='imgs')\n", + " train_iter = train_iter.batch(batch_size=batch_size, drop_remainder=True)\n", + " val_iter = val_iter.map(mindspore.dataset.vision.HWC2CHW(),input_columns='imgs')\n", + " val_iter = val_iter.batch(batch_size=batch_size, drop_remainder=True)\n", + " return train_iter, val_iter\n", + "\n", + "batch_size = 32\n", + "train_iter, _ = load_data_bananas(batch_size)\n" + ] + }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 25, "id": "d80ea58d", "metadata": {}, "outputs": [ @@ -525,6 +603,46 @@ "name": "stdout", "output_type": "stream", "text": [ + "342\n", + "811\n", + "892\n", + "322\n", + "782\n", + "253\n", + "765\n", + "269\n", + "656\n", + "825\n", + "213\n", + "673\n", + "502\n", + "440\n", + "26\n", + "144\n", + "455\n", + "791\n", + "945\n", + "545\n", + "145\n", + "667\n", + "857\n", + "192\n", + "154\n", + "740\n", + "445\n", + "134\n", + "75\n", + "139\n", + "115\n", + "476\n", + "534\n", + "34\n", + "426\n", + "731\n", + "920\n", + "751\n", + "404\n", + "556\n", "(32, 3, 256, 256) (32, 1, 5)\n" ] } @@ -538,7 +656,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 26, "id": "ca1c0d14", "metadata": {}, "outputs": [ @@ -561,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 27, "id": "d1f7601f", "metadata": { "scrolled": true @@ -571,8 +689,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.896 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n", - "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.965 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n" + "[WARNING] KERNEL(52379,2b5b250b1d80,python):2023-03-07-21:09:42.502.653 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n", + "[WARNING] KERNEL(52379,2b5b250b1d80,python):2023-03-07-21:09:42.502.718 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n" ] }, { @@ -594,17 +712,17 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 28, "id": "993302d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Tensor(shape=[], dtype=Float32, value= 0.69856)" + "Tensor(shape=[], dtype=Float32, value= 0.698252)" ] }, - "execution_count": 22, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -614,6 +732,66 @@ "l" ] }, + { + "cell_type": "code", + "execution_count": 29, + "id": "21bb0a0c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "638\n", + "612\n", + "10\n", + "198\n", + "896\n", + "582\n", + "783\n", + "304\n", + "781\n", + "271\n", + "873\n", + "588\n", + "532\n", + "310\n", + "770\n", + "599\n", + "931\n", + "119\n", + "805\n", + "128\n", + "249\n", + "196\n", + "856\n", + "120\n", + "932\n", + "869\n", + "657\n", + "823\n", + "989\n", + "964\n", + "855\n", + "277\n", + "308\n", + "892\n", + "923\n", + "948\n", + "368\n", + "688\n", + "721\n", + "(32, 3, 256, 256) (32, 1, 5)\n" + ] + } + ], + "source": [ + "#forward\n", + "for X, Y in train_iter:\n", + " print(X.shape, Y.shape)\n", + " break" + ] + }, { "cell_type": "markdown", "id": "9c78354a", @@ -622,6 +800,80 @@ "# BUG 直接运行的话,这里会卡住" ] }, + { + "cell_type": "code", + "execution_count": 30, + "id": "9ca822e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "73\n", + "432\n", + "151\n", + "576\n", + "755\n", + "405\n", + "366\n", + "889\n", + "701\n", + "146\n", + "611\n", + "377\n", + "426\n", + "387\n", + "410\n", + "526\n", + "919\n", + "276\n", + "867\n", + "106\n", + "599\n", + "561\n", + "615\n", + "915\n", + "542\n", + "380\n", + "534\n", + "788\n", + "863\n", + "317\n", + "902\n", + "121\n", + "24\n", + "350\n", + "535\n", + "398\n", + "847\n", + "575\n", + "136\n", + "output anchors: (1, 5444, 4)\n", + "output class preds: (32, 5444, 2)\n", + "output bbox preds: (32, 21776)\n", + "bbox_labels: (32, 21776) (32, 21776)\n", + "bbox_masks: (32, 21776) Float32\n", + "cls_labels: (32, 5444) Int32\n" + ] + } + ], + "source": [ + "X, Y = next(train_iter.create_tuple_iterator())\n", + "anchors, cls_preds, bbox_preds = net(X)\n", + "print('output anchors:', anchors.shape)\n", + "print('output class preds:', cls_preds.shape)\n", + "print('output bbox preds:', bbox_preds.shape)\n", + "\n", + "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", + "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", + "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", + "print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", + "\n", + "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", + " bbox_masks)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -636,14 +888,14 @@ " print('output anchors:', anchors.shape)\n", " print('output class preds:', cls_preds.shape)\n", " print('output bbox preds:', bbox_preds.shape)\n", - " \n", + "\n", " bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", " print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", " print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - " print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", + " print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", " \n", " l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - " bbox_masks)\n", + " bbox_masks)\n", " print(l)\n", " break" ] From 42a3aa63ca4893a6f20c16f8cc68c5e372b58eba Mon Sep 17 00:00:00 2001 From: liaozy Date: Tue, 14 Mar 2023 21:30:41 +0800 Subject: [PATCH 2/7] fix_bug --- chapter_11_optimization/lr-scheduler.ipynb | 5112 ++++++++++++++------ 1 file changed, 3721 insertions(+), 1391 deletions(-) diff --git a/chapter_11_optimization/lr-scheduler.ipynb b/chapter_11_optimization/lr-scheduler.ipynb index 397fb76..000ff0d 100644 --- a/chapter_11_optimization/lr-scheduler.ipynb +++ b/chapter_11_optimization/lr-scheduler.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 16, "id": "47980ea0", "metadata": {}, "outputs": [], @@ -30,20 +30,6 @@ "from d2l import mindspore as d2l\n", "from mindspore import value_and_grad\n", "\n", - "# def net_fn():\n", - "# model = nn.SequentialCell(\n", - "# nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.MaxPool2d(kernel_size=2, stride=2),\n", - "# nn.Conv2d(6, 16, kernel_size=5, pad_mode='valid', weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.MaxPool2d(kernel_size=2, stride=2),\n", - "# nn.Flatten(),\n", - "# nn.Dense(16 * 5 * 5, 120, weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.Dense(120, 84, weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.Dense(84, 10, weight_init='xavier_uniform'))\n", - "\n", - "# return model\n", - "\n", - "\n", "def net_fn():\n", " model = nn.SequentialCell(\n", " nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='HeUniform'), nn.ReLU(),\n", @@ -104,36 +90,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 18, "id": "f69d3a9d", "metadata": {}, - "outputs": [], - "source": [ - "lr, num_epochs = 0.3, 30\n", - "net = net_fn()\n", - "trainer = nn.SGD(net.trainable_params(), lr)\n", - "# train(net, train_iter, test_iter, num_epochs, loss, trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "b6142bca", - "metadata": {}, - "source": [ - "# BUG" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "87b461e5", - "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "train loss 694451963869065584049010050596864.000, train acc 0.209, test acc 0.301\n" + "train loss 0.098, train acc 0.963, test acc 0.885\n" ] }, { @@ -142,12 +107,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-23T20:58:52.168125\n", + " 2023-03-14T20:30:49.846707\n", " image/svg+xml\n", " \n", " \n", @@ -163,8 +128,8 @@ " \n", " \n", " \n", @@ -183,16 +148,16 @@ " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -227,18 +192,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -566,22 +533,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -600,17 +567,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -620,17 +587,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -640,17 +607,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -660,17 +627,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -680,39 +647,284 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -913,13 +1125,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -933,13 +1145,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -956,7 +1168,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -971,19 +1183,10 @@ } ], "source": [ - "lr, num_epochs = 0.3, 5\n", - "net10 = nn.SequentialCell(\n", - " nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='HeUniform'), nn.ReLU(),\n", - " nn.MaxPool2d(kernel_size=2, stride=2),\n", - " nn.Conv2d(6, 16, kernel_size=5, pad_mode='valid', weight_init='HeUniform'), nn.ReLU(),\n", - " nn.MaxPool2d(kernel_size=2, stride=2),\n", - " nn.Flatten(),\n", - " nn.Dense(16 * 5 * 5, 120, weight_init='HeUniform'), nn.ReLU(),\n", - " nn.Dense(120, 84, weight_init='HeUniform'), nn.ReLU(),\n", - " nn.Dense(84, 10, weight_init='HeUniform'))\n", - "\n", - "trainer = nn.SGD(net10.trainable_params(), lr)\n", - "train(net10, train_iter, test_iter, num_epochs, loss, trainer)" + "lr, num_epochs = 0.3, 30\n", + "net = net_fn()\n", + "trainer = nn.SGD(net.trainable_params(), lr)\n", + "train(net, train_iter, test_iter, num_epochs, loss, trainer)" ] }, { @@ -996,7 +1199,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "id": "786b7035", "metadata": {}, "outputs": [ @@ -1004,7 +1207,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "learning rate is now 0.30\n" + "learning rate is now 0.10\n" ] } ], @@ -1016,7 +1219,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "7b859ea3", "metadata": {}, "outputs": [], @@ -1031,7 +1234,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "id": "02f69ee7", "metadata": {}, "outputs": [ @@ -1046,7 +1249,7 @@ " \n", " \n", " \n", - " 2023-02-23T21:00:24.300243\n", + " 2023-03-14T16:59:16.363791\n", " image/svg+xml\n", " \n", " \n", @@ -1082,16 +1285,16 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1128,11 +1331,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1163,11 +1366,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1208,11 +1411,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1263,16 +1466,16 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1298,11 +1501,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1340,11 +1543,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1393,11 +1596,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1455,11 +1658,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1504,7 +1707,7 @@ "L 210.643838 138.512248 \n", "L 216.766095 139.018898 \n", "L 222.888352 139.5 \n", - "\" clip-path=\"url(#pc8a5adcc1c)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1558,10 +1761,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "71bbba03", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train loss 0.245, train acc 0.910, test acc 0.887\n" + ] + }, { "data": { "image/svg+xml": [ @@ -1573,7 +1783,7 @@ " \n", " \n", " \n", - " 2023-02-23T21:04:10.181314\n", + " 2023-03-14T17:17:37.381938\n", " image/svg+xml\n", " \n", " \n", @@ -1609,16 +1819,16 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1655,11 +1865,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1700,11 +1910,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1735,11 +1945,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1754,11 +1964,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1799,11 +2009,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1818,11 +2028,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1997,16 +2207,16 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2031,11 +2241,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2051,11 +2261,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2071,11 +2281,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2091,11 +2301,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2109,109 +2319,283 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2486,82 +2870,9 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "c271bc7a", "metadata": {}, - "outputs": [], - "source": [ - "class FactorScheduler:\n", - " def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):\n", - " self.factor = factor\n", - " self.stop_factor_lr = stop_factor_lr\n", - " self.base_lr = base_lr\n", - "\n", - " def __call__(self, num_update):\n", - " self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)\n", - " return self.base_lr\n", - "\n", - "scheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)\n", - "d2l.plot(d2l.arange(50), [scheduler(t) for t in range(50)])" - ] - }, - { - "cell_type": "markdown", - "id": "eb5f8ed5", - "metadata": {}, - "source": [ - "#### 11.11.3.2. 多因子调度器" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0828bec8", - "metadata": {}, - "outputs": [], - "source": [ - "class MultiFactorScheduler:\n", - " def __init__(self, step, factor, base_lr):\n", - " self.step = step\n", - " self.factor = factor\n", - " self.base_lr = base_lr\n", - "\n", - " def __call__(self, epoch):\n", - " if epoch in self.step:\n", - " self.base_lr = self.base_lr * self.factor\n", - " return self.base_lr\n", - " else:\n", - " return self.base_lr\n", - "\n", - "scheduler = MultiFactorScheduler(step=[15, 30], factor=0.5, base_lr=0.5)\n", - "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "111352a0", - "metadata": {}, - "outputs": [], - "source": [ - "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", - "trainer = nn.SGD(net.trainable_params(), lr_list)\n", - "train(net, train_iter, test_iter, num_epochs, loss, trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "4c27a8f8", - "metadata": {}, - "source": [ - "#### 11.11.3.3. 余弦调度器" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "8e8a5c3e", - "metadata": {}, "outputs": [ { "data": { @@ -2569,12 +2880,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:05.292287\n", + " 2023-03-14T17:17:37.758737\n", " image/svg+xml\n", " \n", " \n", @@ -2590,41 +2901,41 @@ " \n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class FactorScheduler:\n", + " def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):\n", + " self.factor = factor\n", + " self.stop_factor_lr = stop_factor_lr\n", + " self.base_lr = base_lr\n", + "\n", + " def __call__(self, num_update):\n", + " self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)\n", + " return self.base_lr\n", + "\n", + "scheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)\n", + "d2l.plot(d2l.arange(50), [scheduler(t) for t in range(50)])" + ] + }, + { + "cell_type": "markdown", + "id": "eb5f8ed5", + "metadata": {}, + "source": [ + "#### 11.11.3.2. 多因子调度器" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0828bec8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-03-14T17:17:38.680374\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class MultiFactorScheduler:\n", + " def __init__(self, step, factor, base_lr):\n", + " self.step = step\n", + " self.factor = factor\n", + " self.base_lr = base_lr\n", + "\n", + " def __call__(self, epoch):\n", + " if epoch in self.step:\n", + " self.base_lr = self.base_lr * self.factor\n", + " return self.base_lr\n", + " else:\n", + " return self.base_lr\n", + "\n", + "scheduler = MultiFactorScheduler(step=[15, 30], factor=0.5, base_lr=0.5)\n", + "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "111352a0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train loss 0.273, train acc 0.896, test acc 0.867\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-03-14T17:35:42.344129\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -3018,15 +4992,23 @@ } ], "source": [ - "# 缺少 warmup, 所以自定义 CosineScheduler\n", - "scheduler = nn.CosineDecayLR(min_lr=0.01, max_lr=0.3, decay_steps=20)\n", - "scheduler(d2l.tensor(0))\n", - "d2l.plot(d2l.arange(num_epochs), scheduler(d2l.tensor([t for t in range(num_epochs)])))" + "net = net_fn()\n", + "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", + "trainer = nn.SGD(net.trainable_params(), lr_list)\n", + "train(net, train_iter, test_iter, num_epochs, loss, trainer)" + ] + }, + { + "cell_type": "markdown", + "id": "4c27a8f8", + "metadata": {}, + "source": [ + "#### 11.11.3.3. 余弦调度器" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "60ba8445", "metadata": {}, "outputs": [ @@ -3036,12 +5018,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:05.569894\n", + " 2023-03-14T17:35:44.349617\n", " image/svg+xml\n", " \n", " \n", @@ -3057,41 +5039,41 @@ " \n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -3515,7 +5432,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "5312a7d2", "metadata": {}, "outputs": [ @@ -3523,7 +5440,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "train loss 3089455025148719102164992.000, train acc 0.135, test acc 0.142\n" + "train loss 0.143, train acc 0.948, test acc 0.901\n" ] }, { @@ -3532,12 +5449,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:57.717367\n", + " 2023-03-14T17:55:22.876697\n", " image/svg+xml\n", " \n", " \n", @@ -3553,8 +5470,8 @@ " \n", " \n", " \n", @@ -3573,16 +5490,16 @@ " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3617,18 +5534,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3956,22 +5875,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3990,17 +5909,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4010,17 +5929,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4030,17 +5949,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4050,17 +5969,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4070,39 +5989,284 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4303,13 +6467,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4323,13 +6487,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4346,7 +6510,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4362,7 +6526,6 @@ ], "source": [ "net = net_fn()\n", - "\n", "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", "trainer = nn.SGD(net.trainable_params(), lr_list)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer)" @@ -4378,7 +6541,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "ee8a5ba8", "metadata": {}, "outputs": [ @@ -4388,12 +6551,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:57.990175\n", + " 2023-03-14T17:55:23.108335\n", " image/svg+xml\n", " \n", " \n", @@ -4408,42 +6571,42 @@ " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -4862,7 +6941,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "05973299", "metadata": {}, "outputs": [ @@ -4870,7 +6949,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "train loss 66452015380.343, train acc 0.100, test acc 0.100\n" + "train loss 0.171, train acc 0.938, test acc 0.898\n" ] }, { @@ -4879,12 +6958,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:05:52.549844\n", + " 2023-03-14T18:16:29.354051\n", " image/svg+xml\n", " \n", " \n", @@ -4900,8 +6979,8 @@ " \n", " \n", " \n", @@ -4920,16 +6999,16 @@ " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4964,18 +7043,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5303,22 +7384,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5337,17 +7418,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5357,17 +7438,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5377,17 +7458,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5397,17 +7478,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5417,44 +7498,281 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5655,13 +7973,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5675,13 +7993,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5698,7 +8016,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5715,7 +8033,6 @@ "source": [ "net = net_fn()\n", "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", - "print(lr_list)\n", "trainer = nn.SGD(net.trainable_params(), lr_list)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer)" ] @@ -5745,7 +8062,20 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.8.10" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false } }, "nbformat": 4, From 6bedbda53130affbedc0d4534197b5ee26ff0417 Mon Sep 17 00:00:00 2001 From: liaozy Date: Tue, 7 Mar 2023 21:16:00 +0800 Subject: [PATCH 3/7] issue --- chapter_13_chapter_computer-vision/ssd.ipynb | 274 ++++++++++++++++++- 1 file changed, 263 insertions(+), 11 deletions(-) diff --git a/chapter_13_chapter_computer-vision/ssd.ipynb b/chapter_13_chapter_computer-vision/ssd.ipynb index 80d6164..a3743b4 100644 --- a/chapter_13_chapter_computer-vision/ssd.ipynb +++ b/chapter_13_chapter_computer-vision/ssd.ipynb @@ -515,9 +515,87 @@ "# bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)" ] }, + { + "cell_type": "markdown", + "id": "5b25ac81", + "metadata": {}, + "source": [ + "# debug" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d717ec27", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "read 1000 training examples\n", + "read 100 validation examples\n" + ] + } + ], + "source": [ + "import os\n", + "import pandas as pd\n", + "import mindspore.dataset as ds\n", + "import numpy as np\n", + "\n", + "def read_data_bananas(is_train=True):\n", + " \"\"\"读取香蕉检测数据集中的图像和标签\"\"\"\n", + " data_dir = d2l.download_extract('banana-detection')\n", + " csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')\n", + " csv_data = pd.read_csv(csv_fname)\n", + " csv_data = csv_data.set_index('img_name')\n", + " images, targets = [], []\n", + " for img_name, target in csv_data.iterrows():\n", + " images.append(ds.vision.read_image(\n", + " os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))\n", + " # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),\n", + " # 其中所有图像都具有相同的香蕉类(索引为0)\n", + " targets.append(list(target))\n", + " return images, mindspore.Tensor(targets, dtype=mindspore.float32).unsqueeze(1) / 255\n", + "\n", + "\n", + "class BananasDataset():\n", + " \"\"\"一个用于加载香蕉检测数据集的自定义数据集\"\"\"\n", + " def __init__(self, is_train):\n", + " self.parent = None\n", + " self.features, self.labels = read_data_bananas(is_train)\n", + " print('read ' + str(len(self.features)) + (f' training examples' if\n", + " is_train else f' validation examples'))\n", + "\n", + " def __getitem__(self, idx):\n", + " print(idx)\n", + " return (np.array(self.features[int(idx)], dtype='float32'), self.labels[int(idx)])\n", + "\n", + " def __len__(self):\n", + " return len(self.features)\n", + "\n", + "\n", + "\n", + "def load_data_bananas(batch_size):\n", + " \"\"\"加载香蕉检测数据集\"\"\"\n", + " train_iter = ds.GeneratorDataset(source=BananasDataset(is_train=True),\n", + " column_names=['imgs', 'labels'], shuffle=True)\n", + " val_iter = ds.GeneratorDataset(source=BananasDataset(is_train=False),\n", + " column_names=['imgs', 'labels'], shuffle=False)\n", + " train_iter = train_iter.map(mindspore.dataset.vision.HWC2CHW(),input_columns='imgs')\n", + " train_iter = train_iter.batch(batch_size=batch_size, drop_remainder=True)\n", + " val_iter = val_iter.map(mindspore.dataset.vision.HWC2CHW(),input_columns='imgs')\n", + " val_iter = val_iter.batch(batch_size=batch_size, drop_remainder=True)\n", + " return train_iter, val_iter\n", + "\n", + "batch_size = 32\n", + "train_iter, _ = load_data_bananas(batch_size)\n" + ] + }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 25, "id": "d80ea58d", "metadata": {}, "outputs": [ @@ -525,6 +603,46 @@ "name": "stdout", "output_type": "stream", "text": [ + "342\n", + "811\n", + "892\n", + "322\n", + "782\n", + "253\n", + "765\n", + "269\n", + "656\n", + "825\n", + "213\n", + "673\n", + "502\n", + "440\n", + "26\n", + "144\n", + "455\n", + "791\n", + "945\n", + "545\n", + "145\n", + "667\n", + "857\n", + "192\n", + "154\n", + "740\n", + "445\n", + "134\n", + "75\n", + "139\n", + "115\n", + "476\n", + "534\n", + "34\n", + "426\n", + "731\n", + "920\n", + "751\n", + "404\n", + "556\n", "(32, 3, 256, 256) (32, 1, 5)\n" ] } @@ -538,7 +656,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 26, "id": "ca1c0d14", "metadata": {}, "outputs": [ @@ -561,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 27, "id": "d1f7601f", "metadata": { "scrolled": true @@ -571,8 +689,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.896 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n", - "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.965 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n" + "[WARNING] KERNEL(52379,2b5b250b1d80,python):2023-03-07-21:09:42.502.653 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n", + "[WARNING] KERNEL(52379,2b5b250b1d80,python):2023-03-07-21:09:42.502.718 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n" ] }, { @@ -594,17 +712,17 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 28, "id": "993302d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Tensor(shape=[], dtype=Float32, value= 0.69856)" + "Tensor(shape=[], dtype=Float32, value= 0.698252)" ] }, - "execution_count": 22, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -614,6 +732,66 @@ "l" ] }, + { + "cell_type": "code", + "execution_count": 29, + "id": "21bb0a0c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "638\n", + "612\n", + "10\n", + "198\n", + "896\n", + "582\n", + "783\n", + "304\n", + "781\n", + "271\n", + "873\n", + "588\n", + "532\n", + "310\n", + "770\n", + "599\n", + "931\n", + "119\n", + "805\n", + "128\n", + "249\n", + "196\n", + "856\n", + "120\n", + "932\n", + "869\n", + "657\n", + "823\n", + "989\n", + "964\n", + "855\n", + "277\n", + "308\n", + "892\n", + "923\n", + "948\n", + "368\n", + "688\n", + "721\n", + "(32, 3, 256, 256) (32, 1, 5)\n" + ] + } + ], + "source": [ + "#forward\n", + "for X, Y in train_iter:\n", + " print(X.shape, Y.shape)\n", + " break" + ] + }, { "cell_type": "markdown", "id": "9c78354a", @@ -622,6 +800,80 @@ "# BUG 直接运行的话,这里会卡住" ] }, + { + "cell_type": "code", + "execution_count": 30, + "id": "9ca822e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "73\n", + "432\n", + "151\n", + "576\n", + "755\n", + "405\n", + "366\n", + "889\n", + "701\n", + "146\n", + "611\n", + "377\n", + "426\n", + "387\n", + "410\n", + "526\n", + "919\n", + "276\n", + "867\n", + "106\n", + "599\n", + "561\n", + "615\n", + "915\n", + "542\n", + "380\n", + "534\n", + "788\n", + "863\n", + "317\n", + "902\n", + "121\n", + "24\n", + "350\n", + "535\n", + "398\n", + "847\n", + "575\n", + "136\n", + "output anchors: (1, 5444, 4)\n", + "output class preds: (32, 5444, 2)\n", + "output bbox preds: (32, 21776)\n", + "bbox_labels: (32, 21776) (32, 21776)\n", + "bbox_masks: (32, 21776) Float32\n", + "cls_labels: (32, 5444) Int32\n" + ] + } + ], + "source": [ + "X, Y = next(train_iter.create_tuple_iterator())\n", + "anchors, cls_preds, bbox_preds = net(X)\n", + "print('output anchors:', anchors.shape)\n", + "print('output class preds:', cls_preds.shape)\n", + "print('output bbox preds:', bbox_preds.shape)\n", + "\n", + "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", + "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", + "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", + "print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", + "\n", + "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", + " bbox_masks)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -636,14 +888,14 @@ " print('output anchors:', anchors.shape)\n", " print('output class preds:', cls_preds.shape)\n", " print('output bbox preds:', bbox_preds.shape)\n", - " \n", + "\n", " bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", " print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", " print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - " print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", + " print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", " \n", " l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - " bbox_masks)\n", + " bbox_masks)\n", " print(l)\n", " break" ] From 5fb8e515328d1d6a2b9139c4dbbb25b8bb3d467d Mon Sep 17 00:00:00 2001 From: liaozy Date: Tue, 14 Mar 2023 21:30:41 +0800 Subject: [PATCH 4/7] fix_bug --- chapter_11_optimization/lr-scheduler.ipynb | 5112 ++++++++++++++------ 1 file changed, 3721 insertions(+), 1391 deletions(-) diff --git a/chapter_11_optimization/lr-scheduler.ipynb b/chapter_11_optimization/lr-scheduler.ipynb index 397fb76..000ff0d 100644 --- a/chapter_11_optimization/lr-scheduler.ipynb +++ b/chapter_11_optimization/lr-scheduler.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 16, "id": "47980ea0", "metadata": {}, "outputs": [], @@ -30,20 +30,6 @@ "from d2l import mindspore as d2l\n", "from mindspore import value_and_grad\n", "\n", - "# def net_fn():\n", - "# model = nn.SequentialCell(\n", - "# nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.MaxPool2d(kernel_size=2, stride=2),\n", - "# nn.Conv2d(6, 16, kernel_size=5, pad_mode='valid', weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.MaxPool2d(kernel_size=2, stride=2),\n", - "# nn.Flatten(),\n", - "# nn.Dense(16 * 5 * 5, 120, weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.Dense(120, 84, weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.Dense(84, 10, weight_init='xavier_uniform'))\n", - "\n", - "# return model\n", - "\n", - "\n", "def net_fn():\n", " model = nn.SequentialCell(\n", " nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='HeUniform'), nn.ReLU(),\n", @@ -104,36 +90,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 18, "id": "f69d3a9d", "metadata": {}, - "outputs": [], - "source": [ - "lr, num_epochs = 0.3, 30\n", - "net = net_fn()\n", - "trainer = nn.SGD(net.trainable_params(), lr)\n", - "# train(net, train_iter, test_iter, num_epochs, loss, trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "b6142bca", - "metadata": {}, - "source": [ - "# BUG" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "87b461e5", - "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "train loss 694451963869065584049010050596864.000, train acc 0.209, test acc 0.301\n" + "train loss 0.098, train acc 0.963, test acc 0.885\n" ] }, { @@ -142,12 +107,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-23T20:58:52.168125\n", + " 2023-03-14T20:30:49.846707\n", " image/svg+xml\n", " \n", " \n", @@ -163,8 +128,8 @@ " \n", " \n", " \n", @@ -183,16 +148,16 @@ " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -227,18 +192,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -566,22 +533,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -600,17 +567,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -620,17 +587,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -640,17 +607,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -660,17 +627,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p3b15d4f73a)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -680,39 +647,284 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -913,13 +1125,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -933,13 +1145,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -956,7 +1168,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -971,19 +1183,10 @@ } ], "source": [ - "lr, num_epochs = 0.3, 5\n", - "net10 = nn.SequentialCell(\n", - " nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='HeUniform'), nn.ReLU(),\n", - " nn.MaxPool2d(kernel_size=2, stride=2),\n", - " nn.Conv2d(6, 16, kernel_size=5, pad_mode='valid', weight_init='HeUniform'), nn.ReLU(),\n", - " nn.MaxPool2d(kernel_size=2, stride=2),\n", - " nn.Flatten(),\n", - " nn.Dense(16 * 5 * 5, 120, weight_init='HeUniform'), nn.ReLU(),\n", - " nn.Dense(120, 84, weight_init='HeUniform'), nn.ReLU(),\n", - " nn.Dense(84, 10, weight_init='HeUniform'))\n", - "\n", - "trainer = nn.SGD(net10.trainable_params(), lr)\n", - "train(net10, train_iter, test_iter, num_epochs, loss, trainer)" + "lr, num_epochs = 0.3, 30\n", + "net = net_fn()\n", + "trainer = nn.SGD(net.trainable_params(), lr)\n", + "train(net, train_iter, test_iter, num_epochs, loss, trainer)" ] }, { @@ -996,7 +1199,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "id": "786b7035", "metadata": {}, "outputs": [ @@ -1004,7 +1207,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "learning rate is now 0.30\n" + "learning rate is now 0.10\n" ] } ], @@ -1016,7 +1219,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "7b859ea3", "metadata": {}, "outputs": [], @@ -1031,7 +1234,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "id": "02f69ee7", "metadata": {}, "outputs": [ @@ -1046,7 +1249,7 @@ " \n", " \n", " \n", - " 2023-02-23T21:00:24.300243\n", + " 2023-03-14T16:59:16.363791\n", " image/svg+xml\n", " \n", " \n", @@ -1082,16 +1285,16 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1128,11 +1331,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1163,11 +1366,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1208,11 +1411,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1263,16 +1466,16 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1298,11 +1501,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1340,11 +1543,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1393,11 +1596,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1455,11 +1658,11 @@ " \n", " \n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1504,7 +1707,7 @@ "L 210.643838 138.512248 \n", "L 216.766095 139.018898 \n", "L 222.888352 139.5 \n", - "\" clip-path=\"url(#pc8a5adcc1c)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", + "\" clip-path=\"url(#p5a63d737d9)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1558,10 +1761,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "71bbba03", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train loss 0.245, train acc 0.910, test acc 0.887\n" + ] + }, { "data": { "image/svg+xml": [ @@ -1573,7 +1783,7 @@ " \n", " \n", " \n", - " 2023-02-23T21:04:10.181314\n", + " 2023-03-14T17:17:37.381938\n", " image/svg+xml\n", " \n", " \n", @@ -1609,16 +1819,16 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1655,11 +1865,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1700,11 +1910,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1735,11 +1945,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1754,11 +1964,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1799,11 +2009,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1818,11 +2028,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1997,16 +2207,16 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2031,11 +2241,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2051,11 +2261,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2071,11 +2281,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2091,11 +2301,11 @@ " \n", " \n", + "\" clip-path=\"url(#pb855921b33)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2109,109 +2319,283 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2486,82 +2870,9 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "c271bc7a", "metadata": {}, - "outputs": [], - "source": [ - "class FactorScheduler:\n", - " def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):\n", - " self.factor = factor\n", - " self.stop_factor_lr = stop_factor_lr\n", - " self.base_lr = base_lr\n", - "\n", - " def __call__(self, num_update):\n", - " self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)\n", - " return self.base_lr\n", - "\n", - "scheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)\n", - "d2l.plot(d2l.arange(50), [scheduler(t) for t in range(50)])" - ] - }, - { - "cell_type": "markdown", - "id": "eb5f8ed5", - "metadata": {}, - "source": [ - "#### 11.11.3.2. 多因子调度器" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0828bec8", - "metadata": {}, - "outputs": [], - "source": [ - "class MultiFactorScheduler:\n", - " def __init__(self, step, factor, base_lr):\n", - " self.step = step\n", - " self.factor = factor\n", - " self.base_lr = base_lr\n", - "\n", - " def __call__(self, epoch):\n", - " if epoch in self.step:\n", - " self.base_lr = self.base_lr * self.factor\n", - " return self.base_lr\n", - " else:\n", - " return self.base_lr\n", - "\n", - "scheduler = MultiFactorScheduler(step=[15, 30], factor=0.5, base_lr=0.5)\n", - "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "111352a0", - "metadata": {}, - "outputs": [], - "source": [ - "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", - "trainer = nn.SGD(net.trainable_params(), lr_list)\n", - "train(net, train_iter, test_iter, num_epochs, loss, trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "4c27a8f8", - "metadata": {}, - "source": [ - "#### 11.11.3.3. 余弦调度器" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "8e8a5c3e", - "metadata": {}, "outputs": [ { "data": { @@ -2569,12 +2880,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:05.292287\n", + " 2023-03-14T17:17:37.758737\n", " image/svg+xml\n", " \n", " \n", @@ -2590,41 +2901,41 @@ " \n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class FactorScheduler:\n", + " def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):\n", + " self.factor = factor\n", + " self.stop_factor_lr = stop_factor_lr\n", + " self.base_lr = base_lr\n", + "\n", + " def __call__(self, num_update):\n", + " self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)\n", + " return self.base_lr\n", + "\n", + "scheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)\n", + "d2l.plot(d2l.arange(50), [scheduler(t) for t in range(50)])" + ] + }, + { + "cell_type": "markdown", + "id": "eb5f8ed5", + "metadata": {}, + "source": [ + "#### 11.11.3.2. 多因子调度器" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0828bec8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-03-14T17:17:38.680374\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class MultiFactorScheduler:\n", + " def __init__(self, step, factor, base_lr):\n", + " self.step = step\n", + " self.factor = factor\n", + " self.base_lr = base_lr\n", + "\n", + " def __call__(self, epoch):\n", + " if epoch in self.step:\n", + " self.base_lr = self.base_lr * self.factor\n", + " return self.base_lr\n", + " else:\n", + " return self.base_lr\n", + "\n", + "scheduler = MultiFactorScheduler(step=[15, 30], factor=0.5, base_lr=0.5)\n", + "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "111352a0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train loss 0.273, train acc 0.896, test acc 0.867\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-03-14T17:35:42.344129\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -3018,15 +4992,23 @@ } ], "source": [ - "# 缺少 warmup, 所以自定义 CosineScheduler\n", - "scheduler = nn.CosineDecayLR(min_lr=0.01, max_lr=0.3, decay_steps=20)\n", - "scheduler(d2l.tensor(0))\n", - "d2l.plot(d2l.arange(num_epochs), scheduler(d2l.tensor([t for t in range(num_epochs)])))" + "net = net_fn()\n", + "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", + "trainer = nn.SGD(net.trainable_params(), lr_list)\n", + "train(net, train_iter, test_iter, num_epochs, loss, trainer)" + ] + }, + { + "cell_type": "markdown", + "id": "4c27a8f8", + "metadata": {}, + "source": [ + "#### 11.11.3.3. 余弦调度器" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "60ba8445", "metadata": {}, "outputs": [ @@ -3036,12 +5018,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:05.569894\n", + " 2023-03-14T17:35:44.349617\n", " image/svg+xml\n", " \n", " \n", @@ -3057,41 +5039,41 @@ " \n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -3515,7 +5432,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "5312a7d2", "metadata": {}, "outputs": [ @@ -3523,7 +5440,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "train loss 3089455025148719102164992.000, train acc 0.135, test acc 0.142\n" + "train loss 0.143, train acc 0.948, test acc 0.901\n" ] }, { @@ -3532,12 +5449,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:57.717367\n", + " 2023-03-14T17:55:22.876697\n", " image/svg+xml\n", " \n", " \n", @@ -3553,8 +5470,8 @@ " \n", " \n", " \n", @@ -3573,16 +5490,16 @@ " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3617,18 +5534,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3956,22 +5875,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3990,17 +5909,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4010,17 +5929,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4030,17 +5949,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4050,17 +5969,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p616c191902)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4070,39 +5989,284 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4303,13 +6467,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4323,13 +6487,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4346,7 +6510,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4362,7 +6526,6 @@ ], "source": [ "net = net_fn()\n", - "\n", "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", "trainer = nn.SGD(net.trainable_params(), lr_list)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer)" @@ -4378,7 +6541,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "ee8a5ba8", "metadata": {}, "outputs": [ @@ -4388,12 +6551,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:57.990175\n", + " 2023-03-14T17:55:23.108335\n", " image/svg+xml\n", " \n", " \n", @@ -4408,42 +6571,42 @@ " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -4862,7 +6941,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "05973299", "metadata": {}, "outputs": [ @@ -4870,7 +6949,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "train loss 66452015380.343, train acc 0.100, test acc 0.100\n" + "train loss 0.171, train acc 0.938, test acc 0.898\n" ] }, { @@ -4879,12 +6958,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:05:52.549844\n", + " 2023-03-14T18:16:29.354051\n", " image/svg+xml\n", " \n", " \n", @@ -4900,8 +6979,8 @@ " \n", " \n", " \n", @@ -4920,16 +6999,16 @@ " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4964,18 +7043,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5303,22 +7384,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5337,17 +7418,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5357,17 +7438,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5377,17 +7458,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5397,17 +7478,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#paa7a148b49)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5417,44 +7498,281 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5655,13 +7973,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5675,13 +7993,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5698,7 +8016,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5715,7 +8033,6 @@ "source": [ "net = net_fn()\n", "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", - "print(lr_list)\n", "trainer = nn.SGD(net.trainable_params(), lr_list)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer)" ] @@ -5745,7 +8062,20 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.8.10" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false } }, "nbformat": 4, From 00597ef4996a20b5c11306e2fb75eb33a85ffe0f Mon Sep 17 00:00:00 2001 From: liaozy Date: Sat, 1 Apr 2023 09:41:51 +0800 Subject: [PATCH 5/7] fix_bug --- chapter_11_optimization/lr-scheduler.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chapter_11_optimization/lr-scheduler.ipynb b/chapter_11_optimization/lr-scheduler.ipynb index 000ff0d..4b6a715 100644 --- a/chapter_11_optimization/lr-scheduler.ipynb +++ b/chapter_11_optimization/lr-scheduler.ipynb @@ -1212,8 +1212,8 @@ } ], "source": [ - "# 改写learning_rate参数的值\n", - "mindspore.ops.assign(trainer.learning_rate, d2l.tensor(0.1))\n", + "lr = 0.1\n", + "mindspore.ops.assign(trainer.learning_rate, lr)\n", "print(f'learning rate is now {trainer.learning_rate.data.asnumpy():.2f}')" ] }, From a343a224f867f4b212348a32e278346f0f3859c2 Mon Sep 17 00:00:00 2001 From: liaozy Date: Mon, 3 Apr 2023 17:49:11 +0800 Subject: [PATCH 6/7] ssd --- chapter_13_chapter_computer-vision/ssd.ipynb | 3545 +++++++++++++++--- 1 file changed, 3065 insertions(+), 480 deletions(-) diff --git a/chapter_13_chapter_computer-vision/ssd.ipynb b/chapter_13_chapter_computer-vision/ssd.ipynb index a3743b4..086fb2c 100644 --- a/chapter_13_chapter_computer-vision/ssd.ipynb +++ b/chapter_13_chapter_computer-vision/ssd.ipynb @@ -29,11 +29,20 @@ "execution_count": 1, "id": "e6901d6d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(300375:47194544307584,MainProcess):2023-04-03-16:15:37.357.548 [mindspore/common/_decorator.py:40] 'Fills' is deprecated from version 2.0 and will be removed in a future version, use 'ops.fill' instead.\n" + ] + } + ], "source": [ "%matplotlib inline\n", "import mindspore\n", "from mindspore import nn\n", + "import numpy as np\n", "from d2l import mindspore as d2l\n", "\n", "\n", @@ -103,7 +112,7 @@ "outputs": [], "source": [ "def flatten_pred(pred):\n", - " return d2l.flatten(pred.permute(0, 2, 3, 1)) # flatten不改变0轴的size\n", + " return d2l.flatten(pred.permute(0, 2, 3, 1))\n", "\n", "def concat_preds(preds):\n", " return d2l.concat([flatten_pred(p) for p in preds], axis=1)" @@ -250,7 +259,7 @@ "source": [ "def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):\n", " Y = blk(X)\n", - " anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio) ##### 需要第四章\n", + " anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)\n", " cls_preds = cls_predictor(Y)\n", " bbox_preds = bbox_predictor(Y)\n", " return (Y, anchors, cls_preds, bbox_preds)" @@ -363,7 +372,7 @@ ], "source": [ "batch_size = 32\n", - "train_iter, _ = d2l.load_data_bananas(batch_size) # " + "train_iter, _ = d2l.load_data_bananas(batch_size)" ] }, { @@ -413,570 +422,3146 @@ "source": [ "def cls_eval(cls_preds, cls_labels):\n", " # 由于类别预测结果放在最后一维,argmax需要指定最后一维。\n", - " return float((cls_preds.argmax(axis=-1).type(\n", + " return float((cls_preds.argmax(axis=-1).astype(\n", " cls_labels.dtype) == cls_labels).sum())\n", "\n", "def bbox_eval(bbox_preds, bbox_labels, bbox_masks):\n", - "# return float((d2l.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())\n", " return float(((bbox_labels - bbox_preds) * bbox_masks).abs().sum())" ] }, - { - "cell_type": "code", - "execution_count": 18, - "id": "259d2abc", - "metadata": {}, - "outputs": [], - "source": [ - "from mindspore import ops\n", - "def box_iou(boxes1, boxes2):\n", - " \"\"\"计算两个锚框或边界框列表中成对的交并比\"\"\"\n", - " box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n", - " (boxes[:, 3] - boxes[:, 1]))\n", - " # boxes1,boxes2,areas1,areas2的形状:\n", - " # boxes1:(boxes1的数量,4),\n", - " # boxes2:(boxes2的数量,4),\n", - " # areas1:(boxes1的数量,),\n", - " # areas2:(boxes2的数量,)\n", - " areas1 = box_area(boxes1)\n", - " areas2 = box_area(boxes2)\n", - " # inter_upperlefts,inter_lowerrights,inters的形状:\n", - " # (boxes1的数量,boxes2的数量,2)\n", - " inter_upperlefts = ops.maximum(boxes1[:, None, :2], boxes2[:, :2])\n", - " inter_lowerrights = ops.minimum(boxes1[:, None, 2:], boxes2[:, 2:])\n", - " inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n", - " # inter_areasandunion_areas的形状:(boxes1的数量,boxes2的数量)\n", - " inter_areas = inters[:, :, 0] * inters[:, :, 1]\n", - " union_areas = areas1[:, None] + areas2 - inter_areas\n", - " return inter_areas / union_areas\n", - "\n", - "def assign_anchor_to_bbox(ground_truth, anchors, iou_threshold=0.5):\n", - " \"\"\"将最接近的真实边界框分配给锚框\"\"\"\n", - " num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n", - " # 位于第i行和第j列的元素x_ij是锚框i和真实边界框j的IoU\n", - " jaccard = box_iou(anchors, ground_truth)\n", - " # 对于每个锚框,分配的真实边界框的张量\n", - " anchors_bbox_map = ops.full((num_anchors,), -1, dtype=mindspore.int64)\n", - " # 根据阈值,决定是否分配真实边界框\n", - " indices, max_ious = ops.max(jaccard, axis=1)\n", - " anc_i = ops.nonzero(max_ious >= iou_threshold).reshape(-1)\n", - " box_j = ops.masked_select(indices, max_ious >= iou_threshold)\n", - " anchors_bbox_map[anc_i] = box_j \n", - " col_discard = ops.full((num_anchors,), -1)\n", - " row_discard = ops.full((num_gt_boxes,), -1)\n", - " for _ in range(num_gt_boxes):\n", - " max_idx = ops.argmax(jaccard)\n", - " box_idx = (max_idx % num_gt_boxes).long()\n", - " anc_idx = (max_idx / num_gt_boxes).long()\n", - " anchors_bbox_map[anc_idx] = box_idx\n", - " jaccard[:, box_idx] = col_discard\n", - " jaccard[anc_idx, :] = row_discard\n", - " return anchors_bbox_map\n", - "\n", - "def offset_boxes(anchors, assigned_bb, eps=1e-6):\n", - " \"\"\"对锚框偏移量的转换\"\"\"\n", - " c_anc = d2l.box_corner_to_center(anchors)\n", - " c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n", - " offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n", - " offset_wh = 5 * ops.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n", - " offset = ops.concat([offset_xy, offset_wh], axis=1)\n", - " return offset\n", - "\n", - "def multibox_target(anchors, labels):\n", - " \"\"\"使用真实边界框标记锚框\"\"\"\n", - " batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n", - " batch_offset, batch_mask, batch_class_labels = [], [], []\n", - " num_anchors = anchors.shape[0]\n", - " for i in range(batch_size):\n", - " label = labels[i, :, :]\n", - " anchors_bbox_map = assign_anchor_to_bbox(\n", - " label[:, 1:], anchors)\n", - " bbox_mask = ops.tile((anchors_bbox_map >= 0).float().unsqueeze(-1), (1, 4))\n", - " # 将类标签和分配的边界框坐标初始化为零\n", - " class_labels = ops.zeros(num_anchors, dtype=mindspore.int32)\n", - " assigned_bb = ops.zeros((num_anchors, 4), dtype=mindspore.float32)\n", - " # 使用真实边界框来标记锚框的类别。\n", - " # 如果一个锚框没有被分配,标记其为背景(值为零)\n", - " indices_true = ops.nonzero(anchors_bbox_map >= 0)\n", - " bb_idx = anchors_bbox_map[indices_true]\n", - " class_labels[indices_true] = label[bb_idx, 0].long() + 1 \n", - " assigned_bb[indices_true] = label[bb_idx, 1:]\n", - " # 偏移量转换\n", - " offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n", - " batch_offset.append(offset.reshape(-1))\n", - " batch_mask.append(bbox_mask.reshape(-1))\n", - " batch_class_labels.append(class_labels)\n", - " bbox_offset = ops.stack(batch_offset)\n", - " bbox_mask = ops.stack(batch_mask)\n", - " class_labels = ops.stack(batch_class_labels)\n", - " return (bbox_offset, bbox_mask, class_labels)\n", - "\n", - "# print(anchors.shape, Y.shape) # (1, 5444, 4) (32, 1, 5)\n", - "# bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)" - ] - }, { "cell_type": "markdown", - "id": "5b25ac81", + "id": "6912bb2d", "metadata": {}, "source": [ - "# debug" + "#### 13.7.2.3. 训练模型" ] }, { "cell_type": "code", - "execution_count": 24, - "id": "d717ec27", + "execution_count": 19, + "id": "3870f282", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "read 1000 training examples\n", - "read 100 validation examples\n" + "class err 2.85e-03, bbox mae 2.70e-03\n", + "6.8 examples/sec on \n" ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:09:36.597677\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "import os\n", - "import pandas as pd\n", - "import mindspore.dataset as ds\n", - "import numpy as np\n", - "\n", - "def read_data_bananas(is_train=True):\n", - " \"\"\"读取香蕉检测数据集中的图像和标签\"\"\"\n", - " data_dir = d2l.download_extract('banana-detection')\n", - " csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'label.csv')\n", - " csv_data = pd.read_csv(csv_fname)\n", - " csv_data = csv_data.set_index('img_name')\n", - " images, targets = [], []\n", - " for img_name, target in csv_data.iterrows():\n", - " images.append(ds.vision.read_image(\n", - " os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))\n", - " # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),\n", - " # 其中所有图像都具有相同的香蕉类(索引为0)\n", - " targets.append(list(target))\n", - " return images, mindspore.Tensor(targets, dtype=mindspore.float32).unsqueeze(1) / 255\n", - "\n", - "\n", - "class BananasDataset():\n", - " \"\"\"一个用于加载香蕉检测数据集的自定义数据集\"\"\"\n", - " def __init__(self, is_train):\n", - " self.parent = None\n", - " self.features, self.labels = read_data_bananas(is_train)\n", - " print('read ' + str(len(self.features)) + (f' training examples' if\n", - " is_train else f' validation examples'))\n", - "\n", - " def __getitem__(self, idx):\n", - " print(idx)\n", - " return (np.array(self.features[int(idx)], dtype='float32'), self.labels[int(idx)])\n", - "\n", - " def __len__(self):\n", - " return len(self.features)\n", - "\n", - "\n", - "\n", - "def load_data_bananas(batch_size):\n", - " \"\"\"加载香蕉检测数据集\"\"\"\n", - " train_iter = ds.GeneratorDataset(source=BananasDataset(is_train=True),\n", - " column_names=['imgs', 'labels'], shuffle=True)\n", - " val_iter = ds.GeneratorDataset(source=BananasDataset(is_train=False),\n", - " column_names=['imgs', 'labels'], shuffle=False)\n", - " train_iter = train_iter.map(mindspore.dataset.vision.HWC2CHW(),input_columns='imgs')\n", - " train_iter = train_iter.batch(batch_size=batch_size, drop_remainder=True)\n", - " val_iter = val_iter.map(mindspore.dataset.vision.HWC2CHW(),input_columns='imgs')\n", - " val_iter = val_iter.batch(batch_size=batch_size, drop_remainder=True)\n", - " return train_iter, val_iter\n", + "num_epochs, timer = 20, d2l.Timer()\n", + "animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", + " legend=['class error', 'bbox mae'])\n", "\n", - "batch_size = 32\n", - "train_iter, _ = load_data_bananas(batch_size)\n" + "def forward_fn(X, Y):\n", + " # 生成多尺度的锚框,为每个锚框预测类别和偏移量\n", + " anchors, cls_preds, bbox_preds = net(X)\n", + " # 为每个锚框标注类别和偏移量\n", + " bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)\n", + " # 根据类别和偏移量的预测和标注值计算损失函数\n", + " l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", + " bbox_masks).mean()\n", + " return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", + " \n", + "grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True)\n", + " \n", + "def train_step(inputs, targets):\n", + " (l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels), grads = grad_fn(inputs, targets)\n", + " trainer(grads)\n", + " return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", + " \n", + "for epoch in range(num_epochs):\n", + " # 训练精确度的和,训练精确度的和中的示例数\n", + " # 绝对误差的和,绝对误差的和中的示例数\n", + " metric = d2l.Accumulator(4)\n", + " net.set_train()\n", + " for features, target in train_iter:\n", + " timer.start()\n", + " l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels = train_step(features, target)\n", + " metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),\n", + " bbox_eval(bbox_preds, bbox_labels, bbox_masks),\n", + " bbox_labels.numel())\n", + " cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]\n", + " animator.add(epoch + 1, (cls_err, bbox_mae))\n", + "print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')\n", + "print(f'{train_iter.get_dataset_size()/ timer.stop():.1f} examples/sec on ')" ] }, { - "cell_type": "code", - "execution_count": 25, - "id": "d80ea58d", + "cell_type": "markdown", + "id": "425916e0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "342\n", - "811\n", - "892\n", - "322\n", - "782\n", - "253\n", - "765\n", - "269\n", - "656\n", - "825\n", - "213\n", - "673\n", - "502\n", - "440\n", - "26\n", - "144\n", - "455\n", - "791\n", - "945\n", - "545\n", - "145\n", - "667\n", - "857\n", - "192\n", - "154\n", - "740\n", - "445\n", - "134\n", - "75\n", - "139\n", - "115\n", - "476\n", - "534\n", - "34\n", - "426\n", - "731\n", - "920\n", - "751\n", - "404\n", - "556\n", - "(32, 3, 256, 256) (32, 1, 5)\n" - ] - } - ], "source": [ - "#forward\n", - "for X, Y in train_iter:\n", - " print(X.shape, Y.shape)\n", - " break" + "### 13.7.3. 预测目标" ] }, { "cell_type": "code", - "execution_count": 26, - "id": "ca1c0d14", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "output anchors: (1, 5444, 4)\n", - "output class preds: (32, 5444, 2)\n", - "output bbox preds: (32, 21776)\n" - ] - } - ], + "execution_count": 20, + "id": "a3b40a56", + "metadata": { + "scrolled": true + }, + "outputs": [], "source": [ - "anchors, cls_preds, bbox_preds = net(X)\n", - "print('output anchors:', anchors.shape)\n", - "print('output class preds:', cls_preds.shape)\n", - "print('output bbox preds:', bbox_preds.shape)" + "img = mindspore.dataset.vision.read_image('../img/banana.jpg').astype('float32')\n", + "X = mindspore.Tensor(img.transpose(2,0,1), dtype=mindspore.float32).unsqueeze(0)" ] }, { "cell_type": "code", - "execution_count": 27, - "id": "d1f7601f", - "metadata": { - "scrolled": true - }, + "execution_count": 24, + "id": "ab17d783", + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[WARNING] KERNEL(52379,2b5b250b1d80,python):2023-03-07-21:09:42.502.653 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n", - "[WARNING] KERNEL(52379,2b5b250b1d80,python):2023-03-07-21:09:42.502.718 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bbox_labels: (32, 21776) (32, 21776)\n", - "bbox_masks: (32, 21776) Float32\n", - "cls_labels: (32, 5444) Int32\n" + "[WARNING] KERNEL(300375,2aec9fa87700,python):2023-04-03-17:11:46.720.731 [mindspore/ccsrc/kernel/kernel.h:514] CheckShapeNull] For 'Reshape', the shape of input cannot contain zero, but got (0, 1)\n", + "[WARNING] KERNEL(300375,2aec9fa87700,python):2023-04-03-17:11:46.723.968 [mindspore/ccsrc/kernel/kernel.h:514] CheckShapeNull] For 'Add', the shape of input_0 cannot contain zero, but got (0)\n", + "[WARNING] KERNEL(300375,2aec9fa87700,python):2023-04-03-17:11:46.734.618 [mindspore/ccsrc/kernel/kernel.h:514] CheckShapeNull] For 'Gather', the shape of indices cannot contain zero, but got (0)\n", + "<__array_function__ internals>:200: RuntimeWarning: invalid value encountered in cast\n" ] } ], "source": [ - "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", - "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", - "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - "print('cls_labels:', cls_labels.shape, cls_labels.dtype)" + "def predict(X):\n", + " net.set_train(False)\n", + " anchors, cls_preds, bbox_preds = net(X)\n", + " cls_probs = d2l.softmax(cls_preds, axis=2).permute(0, 2, 1)\n", + " output = d2l.multibox_detection(cls_probs, bbox_preds, anchors) # d2l.\n", + " idx = [i for i, row in enumerate(output[0]) if row[0] != -1]\n", + " return output[0][idx]\n", + "\n", + "output = predict(X)" ] }, { "cell_type": "code", - "execution_count": 28, - "id": "993302d8", + "execution_count": 26, + "id": "7c7d41f6", "metadata": {}, "outputs": [ { "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:18:28.618567\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], "text/plain": [ - "Tensor(shape=[], dtype=Float32, value= 0.698252)" + "
" ] }, - "execution_count": 28, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks).mean()\n", - "l" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "21bb0a0c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "638\n", - "612\n", - "10\n", - "198\n", - "896\n", - "582\n", - "783\n", - "304\n", - "781\n", - "271\n", - "873\n", - "588\n", - "532\n", - "310\n", - "770\n", - "599\n", - "931\n", - "119\n", - "805\n", - "128\n", - "249\n", - "196\n", - "856\n", - "120\n", - "932\n", - "869\n", - "657\n", - "823\n", - "989\n", - "964\n", - "855\n", - "277\n", - "308\n", - "892\n", - "923\n", - "948\n", - "368\n", - "688\n", - "721\n", - "(32, 3, 256, 256) (32, 1, 5)\n" - ] - } - ], - "source": [ - "#forward\n", - "for X, Y in train_iter:\n", - " print(X.shape, Y.shape)\n", - " break" + "def display(img, output, threshold):\n", + " d2l.set_figsize((5, 5))\n", + " fig = d2l.plt.imshow(img)\n", + " for row in output:\n", + " score = float(row[1])\n", + " if score < threshold:\n", + " continue\n", + " h, w = img.shape[0:2]\n", + " bbox = [row[2:6] * d2l.tensor((w, h, w, h))]\n", + " d2l.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')\n", + "img = X.numpy().squeeze().transpose(1,2,0)/255\n", + "display(img, output, threshold=0.9)" ] }, { "cell_type": "markdown", - "id": "9c78354a", + "id": "a49128d0", "metadata": {}, "source": [ - "# BUG 直接运行的话,这里会卡住" + "### 13.7.5. 练习" ] }, { "cell_type": "code", - "execution_count": 30, - "id": "9ca822e6", + "execution_count": 27, + "id": "77b6ea35", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "73\n", - "432\n", - "151\n", - "576\n", - "755\n", - "405\n", - "366\n", - "889\n", - "701\n", - "146\n", - "611\n", - "377\n", - "426\n", - "387\n", - "410\n", - "526\n", - "919\n", - "276\n", - "867\n", - "106\n", - "599\n", - "561\n", - "615\n", - "915\n", - "542\n", - "380\n", - "534\n", - "788\n", - "863\n", - "317\n", - "902\n", - "121\n", - "24\n", - "350\n", - "535\n", - "398\n", - "847\n", - "575\n", - "136\n", - "output anchors: (1, 5444, 4)\n", - "output class preds: (32, 5444, 2)\n", - "output bbox preds: (32, 21776)\n", - "bbox_labels: (32, 21776) (32, 21776)\n", - "bbox_masks: (32, 21776) Float32\n", - "cls_labels: (32, 5444) Int32\n" - ] + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:18:32.085257\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "X, Y = next(train_iter.create_tuple_iterator())\n", - "anchors, cls_preds, bbox_preds = net(X)\n", - "print('output anchors:', anchors.shape)\n", - "print('output class preds:', cls_preds.shape)\n", - "print('output bbox preds:', bbox_preds.shape)\n", - "\n", - "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", - "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", - "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - "print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", - "\n", - "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - " bbox_masks)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5e0d0317", - "metadata": {}, - "outputs": [], - "source": [ - "#forward\n", - "for X, Y in train_iter:\n", - " print(X.shape, Y.shape)\n", - " anchors, cls_preds, bbox_preds = net(X)\n", - " print('output anchors:', anchors.shape)\n", - " print('output class preds:', cls_preds.shape)\n", - " print('output bbox preds:', bbox_preds.shape)\n", - "\n", - " bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", - " print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", - " print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - " print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", + "def smooth_l1(data, scalar):\n", + " out = []\n", + " for i in data:\n", + " if abs(i) < 1 / (scalar ** 2):\n", + " out.append(((scalar * i) ** 2) / 2)\n", + " else:\n", + " out.append(abs(i) - 0.5 / (scalar ** 2))\n", " \n", - " l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - " bbox_masks)\n", - " print(l)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6ef20694", - "metadata": {}, - "outputs": [], - "source": [ - "X, Y = next(train_iter)\n", - "print(X.shape, Y.shape)\n", - "anchors, cls_preds, bbox_preds = net(X)\n", - "print('output anchors:', anchors.shape)\n", - "print('output class preds:', cls_preds.shape)\n", - "print('output bbox preds:', bbox_preds.shape)\n", + " return np.array(out)\n", "\n", - "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", - "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", - "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - "print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", + "sigmas = [10, 1, 0.5]\n", + "lines = ['-', '--', '-.']\n", + "x = d2l.arange(-2, 2, 0.1).numpy()\n", + "d2l.set_figsize()\n", "\n", - "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - " bbox_masks)\n", - "print(l)\n" - ] - }, - { - "cell_type": "markdown", - "id": "6912bb2d", - "metadata": {}, - "source": [ - "#### 13.7.2.3. 训练模型" + "for l, s in zip(lines, sigmas):\n", + " y = smooth_l1(x, scalar=s)\n", + " \n", + " d2l.plt.plot(x, y, l, label='sigma=%.1f' % s)\n", + "d2l.plt.legend();" ] }, { "cell_type": "code", - "execution_count": null, - "id": "3870f282", + "execution_count": 28, + "id": "289f5472", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:18:32.544692\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# num_epochs, timer = 20, d2l.Timer()\n", - "# animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", - "# legend=['class error', 'bbox mae'])\n", + "def focal_loss(gamma, x):\n", + " return -(1 - x) ** gamma * np.log(x)\n", "\n", - "# def forward_fn(X, Y):\n", - "# # 生成多尺度的锚框,为每个锚框预测类别和偏移量\n", - "# anchors, cls_preds, bbox_preds = net(X)\n", - "# # 为每个锚框标注类别和偏移量\n", - "# bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y) # d2l.\n", - "# # 根据类别和偏移量的预测和标注值计算损失函数\n", - "# l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - "# bbox_masks).mean()\n", - "# return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", - " \n", - "# grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True)\n", - " \n", - "# def train_step(inputs, targets):\n", - "# (l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels), grads = grad_fn(inputs, targets)\n", - "# trainer(grads)\n", - "# return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", - " \n", - "# for epoch in range(num_epochs):\n", - "# # 训练精确度的和,训练精确度的和中的示例数\n", - "# # 绝对误差的和,绝对误差的和中的示例数\n", - "# metric = d2l.Accumulator(4)\n", - "# net.set_train()\n", - "# for features, target in train_iter:\n", - "# timer.start()\n", - "# print(epoch)\n", - "# l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels = train_step(features, target)\n", - "# metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),\n", - "# bbox_eval(bbox_preds, bbox_labels, bbox_masks),\n", - "# bbox_labels.numel())\n", - "# cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]\n", - "# animator.add(epoch + 1, (cls_err, bbox_mae))\n", - "# print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')\n", - "# print(f'{len(train_iter.dataset) / timer.stop():.1f} examples/sec on '\n", - "# f'{str(device)}')" + "x = np.arange(0.01, 1, 0.01)\n", + "for l, gamma in zip(lines, [0, 1, 5]):\n", + " y = d2l.plt.plot(x, focal_loss(gamma, x), l, label='gamma=%.1f' % gamma)\n", + "d2l.plt.legend();" ] } ], @@ -996,7 +3581,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.8.10" }, "toc": { "base_numbering": 1, From c0c857683124d12fed4705d10845489c2cf1019f Mon Sep 17 00:00:00 2001 From: liaozy Date: Mon, 3 Apr 2023 18:53:46 +0800 Subject: [PATCH 7/7] fix_ssd --- d2l/mindspore.py | 119 ++++++++++++++++++++++++++++------------------- 1 file changed, 70 insertions(+), 49 deletions(-) diff --git a/d2l/mindspore.py b/d2l/mindspore.py index 7b9523b..5a30d75 100644 --- a/d2l/mindspore.py +++ b/d2l/mindspore.py @@ -2485,50 +2485,6 @@ def resnet_block(in_channels, out_channels, num_residuals, return net -def train_batch_ch13(train_step, X, y): - """用GPU进行小批量训练""" - l, pred = train_step(X, y) - train_loss_sum = l.sum() - train_acc_sum = d2l.accuracy(pred, y) - return train_loss_sum, train_acc_sum - -def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs): - """用GPU进行模型训练""" - timer, num_batches = d2l.Timer(), train_iter.get_dataset_size() - animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], - legend=['train loss', 'train acc', 'test acc']) - - def forward_fn(inputs, targets): - logits = net(inputs) - l = loss(logits, targets) - return l, logits - - grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True) - - def train_step(inputs, targets): - (l, logits), grads = grad_fn(inputs, targets) - trainer(grads) - return l, logits - - for epoch in range(num_epochs): - # 4个维度:储存训练损失,训练准确度,实例数,特点数 - metric = d2l.Accumulator(4) - net.set_train() - for i, (features, labels) in enumerate(train_iter): - timer.start() - l, acc = train_batch_ch13( - train_step, features, labels) # , loss, trainer, devices - metric.add(l, acc, labels.shape[0], labels.numel()) - timer.stop() - if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1: - animator.add(epoch + (i + 1) / num_batches, - (metric[0] / metric[2], metric[1] / metric[3], - None)) - test_acc = d2l.evaluate_accuracy_gpu(net, test_iter) - animator.add(epoch + 1, (None, None, test_acc)) - print(f'loss {metric[0] / metric[2]:.3f}, train acc ' - f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}') - print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec') def multibox_prior(data, sizes, ratios): """生成以每个像素为中心具有不同形状的锚框""" @@ -2648,7 +2604,7 @@ def assign_anchor_to_bbox(ground_truth, anchors, iou_threshold=0.5): # 对于每个锚框,分配的真实边界框的张量 anchors_bbox_map = ops.full((num_anchors,), -1, dtype=mindspore.int64) # 根据阈值,决定是否分配真实边界框 - indices, max_ious = ops.max(jaccard, axis=1) + max_ious, indices = ops.max(jaccard, axis=1) anc_i = ops.nonzero(max_ious >= iou_threshold).reshape(-1) box_j = ops.masked_select(indices, max_ious >= iou_threshold) anchors_bbox_map[anc_i] = box_j @@ -2683,14 +2639,14 @@ def multibox_target(anchors, labels): label[:, 1:], anchors) bbox_mask = ops.tile((anchors_bbox_map >= 0).float().unsqueeze(-1), (1, 4)) # 将类标签和分配的边界框坐标初始化为零 - class_labels = ops.zeros(num_anchors, dtype=mindspore.int32) + class_labels = ops.zeros(num_anchors, dtype=mindspore.int64) assigned_bb = ops.zeros((num_anchors, 4), dtype=mindspore.float32) # 使用真实边界框来标记锚框的类别。 # 如果一个锚框没有被分配,标记其为背景(值为零) indices_true = ops.nonzero(anchors_bbox_map >= 0) bb_idx = anchors_bbox_map[indices_true] - class_labels[indices_true] = label[bb_idx, 0].long() + 1 - assigned_bb[indices_true] = label[bb_idx, 1:] + class_labels[indices_true] = label[:, 0][bb_idx].long() + 1 + assigned_bb[indices_true] = label[:, 1:][bb_idx] # 偏移量转换 offset = offset_boxes(anchors, assigned_bb) * bbox_mask batch_offset.append(offset.reshape(-1)) @@ -2698,7 +2654,7 @@ def multibox_target(anchors, labels): batch_class_labels.append(class_labels) bbox_offset = ops.stack(batch_offset) bbox_mask = ops.stack(batch_mask) - class_labels = ops.stack(batch_class_labels) + class_labels = ops.stack(batch_class_labels).astype('int32') return (bbox_offset, bbox_mask, class_labels) d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip', @@ -2746,6 +2702,70 @@ def load_data_bananas(batch_size): return train_iter, val_iter +def offset_inverse(anchors, offset_preds): + """根据带有预测偏移量的锚框来预测边界框""" + anc = d2l.box_corner_to_center(anchors) + pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2] + pred_bbox_wh = ops.exp(offset_preds[:, 2:] / 5) * anc[:, 2:] + pred_bbox = ops.concat((pred_bbox_xy, pred_bbox_wh), axis=1) + predicted_bbox = d2l.box_center_to_corner(pred_bbox) + return predicted_bbox + + +def nms(boxes, scores, iou_threshold): + """对预测边界框的置信度进行排序""" + B = ops.argsort(scores, axis=-1, descending=True) + keep = [] # 保留预测边界框的指标 + while B.numel() > 0: + i = B[0] + keep.append(i.asnumpy()) + if B.numel() == 1: break + iou = box_iou(boxes[i].reshape(-1, 4), + boxes[:, :][B[1:]].reshape(-1, 4)).reshape(-1) + inds = ops.nonzero(iou <= iou_threshold).reshape(-1) + B = B[inds + 1] + return mindspore.Tensor(keep) + + +def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5, + pos_threshold=0.009999999): + """使用非极大值抑制来预测边界框""" + batch_size = cls_probs.shape[0] + anchors = anchors.squeeze(0) + num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2] + out = [] + for i in range(batch_size): + cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4) + conf, class_id = ops.max(cls_prob[1:], 0) + predicted_bb = offset_inverse(anchors, offset_pred) + keep = nms(predicted_bb, conf, nms_threshold) + # 找到所有的non_keep索引,并将类设置为背景 + all_idx = ops.arange(num_anchors, dtype=mindspore.int64) + combined = ops.concat((keep, all_idx)) + unique = ops.unique(combined) + uniques, _ = unique[0], unique[1] + # 统计未重复的,mindspore的unique无法统计同一个元素出现次数 + counts = mnp.bincount(combined) + non_keep = mindspore.Tensor([int(uniques[i]) for i in range(len(counts)) + if (counts[i] == 1)]) + all_id_sorted = ops.concat((keep, non_keep)) + class_id[non_keep] = -1 + class_id = class_id[all_id_sorted] + conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted] + # pos_threshold是一个用于非背景预测的阈值 + below_min_idx = (conf < pos_threshold) + class_id = mindspore.Tensor.float(class_id) + class_id[below_min_idx] = -1 + for j in range(len(conf)): + if below_min_idx[j] == True: + conf[j] = 1 - conf[j] + + pred_info = ops.concat((class_id.unsqueeze(1), + conf.unsqueeze(1).astype('float32'), + predicted_bb), axis=1) + out.append(pred_info) + return ops.stack(out) + d2l.DATA_HUB['voc2012'] = (d2l.DATA_URL + 'VOCtrainval_11-May-2012.tar', '4e443f8a2eca6b1dac8a6c57641b67dd40621a49') @@ -3016,6 +3036,7 @@ def get_pooled_rois(self,feature_map, roi_batch): square = ops.square sqrt = ops.sqrt sign = ops.sign +softmax = ops.softmax meshgrid = ops.meshgrid linspace = ops.linspace zeros_like = ops.zeros_like