diff --git a/chapter_11_optimization/lr-scheduler.ipynb b/chapter_11_optimization/lr-scheduler.ipynb index 397fb76..f1316da 100644 --- a/chapter_11_optimization/lr-scheduler.ipynb +++ b/chapter_11_optimization/lr-scheduler.ipynb @@ -18,10 +18,18 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "47980ea0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(398508:47004611480960,MainProcess):2023-04-03-19:20:51.352.287 [mindspore/common/_decorator.py:40] 'Fills' is deprecated from version 2.0 and will be removed in a future version, use 'ops.fill' instead.\n" + ] + } + ], "source": [ "%matplotlib inline\n", "import math\n", @@ -30,20 +38,6 @@ "from d2l import mindspore as d2l\n", "from mindspore import value_and_grad\n", "\n", - "# def net_fn():\n", - "# model = nn.SequentialCell(\n", - "# nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.MaxPool2d(kernel_size=2, stride=2),\n", - "# nn.Conv2d(6, 16, kernel_size=5, pad_mode='valid', weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.MaxPool2d(kernel_size=2, stride=2),\n", - "# nn.Flatten(),\n", - "# nn.Dense(16 * 5 * 5, 120, weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.Dense(120, 84, weight_init='xavier_uniform'), nn.ReLU(),\n", - "# nn.Dense(84, 10, weight_init='xavier_uniform'))\n", - "\n", - "# return model\n", - "\n", - "\n", "def net_fn():\n", " model = nn.SequentialCell(\n", " nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='HeUniform'), nn.ReLU(),\n", @@ -104,36 +98,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "id": "f69d3a9d", "metadata": {}, - "outputs": [], - "source": [ - "lr, num_epochs = 0.3, 30\n", - "net = net_fn()\n", - "trainer = nn.SGD(net.trainable_params(), lr)\n", - "# train(net, train_iter, test_iter, num_epochs, loss, trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "b6142bca", - "metadata": {}, - "source": [ - "# BUG" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "87b461e5", - "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "train loss 694451963869065584049010050596864.000, train acc 0.209, test acc 0.301\n" + "train loss 0.119, train acc 0.955, test acc 0.889\n" ] }, { @@ -142,12 +115,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-23T20:58:52.168125\n", + " 2023-04-03T19:24:42.750080\n", " image/svg+xml\n", " \n", " \n", @@ -163,8 +136,8 @@ " \n", " \n", " \n", @@ -183,16 +156,16 @@ " \n", " \n", + "\" clip-path=\"url(#pec68262993)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -227,18 +200,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -566,22 +541,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#pec68262993)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -600,17 +575,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#pec68262993)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -620,17 +595,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#pec68262993)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -640,17 +615,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#pec68262993)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -660,17 +635,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#pec68262993)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -680,39 +655,284 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -913,13 +1133,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -933,13 +1153,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -956,7 +1176,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -971,19 +1191,10 @@ } ], "source": [ - "lr, num_epochs = 0.3, 5\n", - "net10 = nn.SequentialCell(\n", - " nn.Conv2d(1, 6, kernel_size=5, padding=2, pad_mode='pad', weight_init='HeUniform'), nn.ReLU(),\n", - " nn.MaxPool2d(kernel_size=2, stride=2),\n", - " nn.Conv2d(6, 16, kernel_size=5, pad_mode='valid', weight_init='HeUniform'), nn.ReLU(),\n", - " nn.MaxPool2d(kernel_size=2, stride=2),\n", - " nn.Flatten(),\n", - " nn.Dense(16 * 5 * 5, 120, weight_init='HeUniform'), nn.ReLU(),\n", - " nn.Dense(120, 84, weight_init='HeUniform'), nn.ReLU(),\n", - " nn.Dense(84, 10, weight_init='HeUniform'))\n", - "\n", - "trainer = nn.SGD(net10.trainable_params(), lr)\n", - "train(net10, train_iter, test_iter, num_epochs, loss, trainer)" + "lr, num_epochs = 0.3, 30\n", + "net = net_fn()\n", + "trainer = nn.SGD(net.trainable_params(), lr)\n", + "train(net, train_iter, test_iter, num_epochs, loss, trainer)" ] }, { @@ -996,7 +1207,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "786b7035", "metadata": {}, "outputs": [ @@ -1004,19 +1215,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "learning rate is now 0.30\n" + "learning rate is now 0.10\n" ] } ], "source": [ - "# 改写learning_rate参数的值\n", - "mindspore.ops.assign(trainer.learning_rate, d2l.tensor(0.1))\n", + "lr = 0.1\n", + "mindspore.ops.assign(trainer.learning_rate, lr)\n", "print(f'learning rate is now {trainer.learning_rate.data.asnumpy():.2f}')" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "7b859ea3", "metadata": {}, "outputs": [], @@ -1046,7 +1257,7 @@ " \n", " \n", " \n", - " 2023-02-23T21:00:24.300243\n", + " 2023-04-03T19:27:25.859227\n", " image/svg+xml\n", " \n", " \n", @@ -1082,16 +1293,16 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1128,11 +1339,11 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1163,11 +1374,11 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1208,11 +1419,11 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1263,16 +1474,16 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1298,11 +1509,11 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1340,11 +1551,11 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1393,11 +1604,11 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1455,11 +1666,11 @@ " \n", " \n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1504,7 +1715,7 @@ "L 210.643838 138.512248 \n", "L 216.766095 139.018898 \n", "L 222.888352 139.5 \n", - "\" clip-path=\"url(#pc8a5adcc1c)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", + "\" clip-path=\"url(#pa44c91e5d2)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1545,7 +1756,7 @@ ], "source": [ "scheduler = SquareRootScheduler(lr=0.1)\n", - "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" + "d2l.plot(d2l.arange(num_epochs).numpy(), [scheduler(t) for t in range(num_epochs)])" ] }, { @@ -1558,10 +1769,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "71bbba03", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train loss 0.236, train acc 0.914, test acc 0.892\n" + ] + }, { "data": { "image/svg+xml": [ @@ -1573,7 +1791,7 @@ " \n", " \n", " \n", - " 2023-02-23T21:04:10.181314\n", + " 2023-04-03T19:31:16.589334\n", " image/svg+xml\n", " \n", " \n", @@ -1609,16 +1827,16 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1655,11 +1873,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1700,11 +1918,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1735,11 +1953,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1754,11 +1972,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1799,11 +2017,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1818,11 +2036,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1997,16 +2215,16 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2031,11 +2249,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2051,11 +2269,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2071,11 +2289,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2091,11 +2309,11 @@ " \n", " \n", + "\" clip-path=\"url(#pe74770e75b)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2109,109 +2327,283 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -2484,83 +2876,10 @@ "#### 11.11.3.1. 单因子调度器" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "c271bc7a", - "metadata": {}, - "outputs": [], - "source": [ - "class FactorScheduler:\n", - " def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):\n", - " self.factor = factor\n", - " self.stop_factor_lr = stop_factor_lr\n", - " self.base_lr = base_lr\n", - "\n", - " def __call__(self, num_update):\n", - " self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)\n", - " return self.base_lr\n", - "\n", - "scheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)\n", - "d2l.plot(d2l.arange(50), [scheduler(t) for t in range(50)])" - ] - }, - { - "cell_type": "markdown", - "id": "eb5f8ed5", - "metadata": {}, - "source": [ - "#### 11.11.3.2. 多因子调度器" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0828bec8", - "metadata": {}, - "outputs": [], - "source": [ - "class MultiFactorScheduler:\n", - " def __init__(self, step, factor, base_lr):\n", - " self.step = step\n", - " self.factor = factor\n", - " self.base_lr = base_lr\n", - "\n", - " def __call__(self, epoch):\n", - " if epoch in self.step:\n", - " self.base_lr = self.base_lr * self.factor\n", - " return self.base_lr\n", - " else:\n", - " return self.base_lr\n", - "\n", - "scheduler = MultiFactorScheduler(step=[15, 30], factor=0.5, base_lr=0.5)\n", - "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "111352a0", - "metadata": {}, - "outputs": [], - "source": [ - "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", - "trainer = nn.SGD(net.trainable_params(), lr_list)\n", - "train(net, train_iter, test_iter, num_epochs, loss, trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "4c27a8f8", - "metadata": {}, - "source": [ - "#### 11.11.3.3. 余弦调度器" - ] - }, { "cell_type": "code", "execution_count": 11, - "id": "8e8a5c3e", + "id": "c271bc7a", "metadata": {}, "outputs": [ { @@ -2569,12 +2888,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:05.292287\n", + " 2023-04-03T19:31:16.820833\n", " image/svg+xml\n", " \n", " \n", @@ -2590,41 +2909,41 @@ " \n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class FactorScheduler:\n", + " def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):\n", + " self.factor = factor\n", + " self.stop_factor_lr = stop_factor_lr\n", + " self.base_lr = base_lr\n", + "\n", + " def __call__(self, num_update):\n", + " self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)\n", + " return self.base_lr\n", + "\n", + "scheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)\n", + "d2l.plot(d2l.arange(50).numpy(), [scheduler(t) for t in range(50)])" + ] + }, + { + "cell_type": "markdown", + "id": "eb5f8ed5", + "metadata": {}, + "source": [ + "#### 11.11.3.2. 多因子调度器" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0828bec8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T19:31:17.029865\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class MultiFactorScheduler:\n", + " def __init__(self, step, factor, base_lr):\n", + " self.step = step\n", + " self.factor = factor\n", + " self.base_lr = base_lr\n", + "\n", + " def __call__(self, epoch):\n", + " if epoch in self.step:\n", + " self.base_lr = self.base_lr * self.factor\n", + " return self.base_lr\n", + " else:\n", + " return self.base_lr\n", + "\n", + "scheduler = MultiFactorScheduler(step=[15, 30], factor=0.5, base_lr=0.5)\n", + "d2l.plot(d2l.arange(num_epochs).numpy(), [scheduler(t) for t in range(num_epochs)])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "111352a0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train loss 0.175, train acc 0.934, test acc 0.892\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T19:35:11.584281\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -3018,15 +5000,23 @@ } ], "source": [ - "# 缺少 warmup, 所以自定义 CosineScheduler\n", - "scheduler = nn.CosineDecayLR(min_lr=0.01, max_lr=0.3, decay_steps=20)\n", - "scheduler(d2l.tensor(0))\n", - "d2l.plot(d2l.arange(num_epochs), scheduler(d2l.tensor([t for t in range(num_epochs)])))" + "net = net_fn()\n", + "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", + "trainer = nn.SGD(net.trainable_params(), lr_list)\n", + "train(net, train_iter, test_iter, num_epochs, loss, trainer)" + ] + }, + { + "cell_type": "markdown", + "id": "4c27a8f8", + "metadata": {}, + "source": [ + "#### 11.11.3.3. 余弦调度器" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "id": "60ba8445", "metadata": {}, "outputs": [ @@ -3036,12 +5026,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:05.569894\n", + " 2023-04-03T19:35:11.857681\n", " image/svg+xml\n", " \n", " \n", @@ -3057,41 +5047,41 @@ " \n", " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -3510,12 +5435,12 @@ " return self.base_lr\n", "\n", "scheduler = CosineScheduler(max_update=20, base_lr=0.3, final_lr=0.01)\n", - "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" + "d2l.plot(d2l.arange(num_epochs).numpy(), [scheduler(t) for t in range(num_epochs)])" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "id": "5312a7d2", "metadata": {}, "outputs": [ @@ -3523,7 +5448,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "train loss 3089455025148719102164992.000, train acc 0.135, test acc 0.142\n" + "train loss 0.158, train acc 0.943, test acc 0.897\n" ] }, { @@ -3532,12 +5457,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:57.717367\n", + " 2023-04-03T19:39:10.365130\n", " image/svg+xml\n", " \n", " \n", @@ -3553,8 +5478,8 @@ " \n", " \n", " \n", @@ -3573,16 +5498,16 @@ " \n", " \n", + "\" clip-path=\"url(#p9b45064230)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3617,18 +5542,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3956,22 +5883,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p9b45064230)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -3990,17 +5917,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p9b45064230)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4010,17 +5937,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p9b45064230)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4030,17 +5957,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p9b45064230)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4050,17 +5977,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p9b45064230)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4070,39 +5997,284 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4303,13 +6475,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4323,13 +6495,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4346,7 +6518,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4362,7 +6534,6 @@ ], "source": [ "net = net_fn()\n", - "\n", "lr_list = d2l.tensor([scheduler(t) for t in range(num_epochs) for i in range(steps_per_epoch)])\n", "trainer = nn.SGD(net.trainable_params(), lr_list)\n", "train(net, train_iter, test_iter, num_epochs, loss, trainer)" @@ -4378,7 +6549,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "id": "ee8a5ba8", "metadata": {}, "outputs": [ @@ -4388,12 +6559,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:04:57.990175\n", + " 2023-04-03T19:39:10.648130\n", " image/svg+xml\n", " \n", " \n", @@ -4408,42 +6579,42 @@ " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "\n" @@ -4857,12 +6944,12 @@ ], "source": [ "scheduler = CosineScheduler(20, warmup_steps=5, base_lr=0.3, final_lr=0.01)\n", - "d2l.plot(d2l.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])" + "d2l.plot(d2l.arange(num_epochs).numpy(), [scheduler(t) for t in range(num_epochs)])" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "id": "05973299", "metadata": {}, "outputs": [ @@ -4870,7 +6957,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "train loss 66452015380.343, train acc 0.100, test acc 0.100\n" + "train loss 0.155, train acc 0.943, test acc 0.897\n" ] }, { @@ -4879,12 +6966,12 @@ "\n", "\n", - "\n", + "\n", " \n", " \n", " \n", " \n", - " 2023-02-22T23:05:52.549844\n", + " 2023-04-03T19:43:01.866532\n", " image/svg+xml\n", " \n", " \n", @@ -4900,8 +6987,8 @@ " \n", " \n", " \n", @@ -4920,16 +7007,16 @@ " \n", " \n", + "\" clip-path=\"url(#p48ced7d4e6)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -4964,18 +7051,63 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5303,22 +7392,22 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p48ced7d4e6)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5337,17 +7426,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p48ced7d4e6)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5357,17 +7446,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p48ced7d4e6)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5377,17 +7466,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p48ced7d4e6)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5397,17 +7486,17 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", + "\" clip-path=\"url(#p48ced7d4e6)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5417,44 +7506,281 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5655,13 +7981,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5675,13 +8001,13 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5698,7 +8024,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -5746,6 +8072,19 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false } }, "nbformat": 4, diff --git a/chapter_13_chapter_computer-vision/ssd.ipynb b/chapter_13_chapter_computer-vision/ssd.ipynb index 80d6164..86c9687 100644 --- a/chapter_13_chapter_computer-vision/ssd.ipynb +++ b/chapter_13_chapter_computer-vision/ssd.ipynb @@ -29,7 +29,15 @@ "execution_count": 1, "id": "e6901d6d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(300375:47194544307584,MainProcess):2023-04-03-16:15:37.357.548 [mindspore/common/_decorator.py:40] 'Fills' is deprecated from version 2.0 and will be removed in a future version, use 'ops.fill' instead.\n" + ] + } + ], "source": [ "%matplotlib inline\n", "import mindspore\n", @@ -103,7 +111,7 @@ "outputs": [], "source": [ "def flatten_pred(pred):\n", - " return d2l.flatten(pred.permute(0, 2, 3, 1)) # flatten不改变0轴的size\n", + " return d2l.flatten(pred.permute(0, 2, 3, 1))\n", "\n", "def concat_preds(preds):\n", " return d2l.concat([flatten_pred(p) for p in preds], axis=1)" @@ -250,7 +258,7 @@ "source": [ "def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):\n", " Y = blk(X)\n", - " anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio) ##### 需要第四章\n", + " anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)\n", " cls_preds = cls_predictor(Y)\n", " bbox_preds = bbox_predictor(Y)\n", " return (Y, anchors, cls_preds, bbox_preds)" @@ -363,7 +371,7 @@ ], "source": [ "batch_size = 32\n", - "train_iter, _ = d2l.load_data_bananas(batch_size) # " + "train_iter, _ = d2l.load_data_bananas(batch_size)" ] }, { @@ -413,318 +421,3147 @@ "source": [ "def cls_eval(cls_preds, cls_labels):\n", " # 由于类别预测结果放在最后一维,argmax需要指定最后一维。\n", - " return float((cls_preds.argmax(axis=-1).type(\n", + " return float((cls_preds.argmax(axis=-1).astype(\n", " cls_labels.dtype) == cls_labels).sum())\n", "\n", "def bbox_eval(bbox_preds, bbox_labels, bbox_masks):\n", - "# return float((d2l.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())\n", " return float(((bbox_labels - bbox_preds) * bbox_masks).abs().sum())" ] }, { - "cell_type": "code", - "execution_count": 18, - "id": "259d2abc", + "cell_type": "markdown", + "id": "6912bb2d", "metadata": {}, - "outputs": [], "source": [ - "from mindspore import ops\n", - "def box_iou(boxes1, boxes2):\n", - " \"\"\"计算两个锚框或边界框列表中成对的交并比\"\"\"\n", - " box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n", - " (boxes[:, 3] - boxes[:, 1]))\n", - " # boxes1,boxes2,areas1,areas2的形状:\n", - " # boxes1:(boxes1的数量,4),\n", - " # boxes2:(boxes2的数量,4),\n", - " # areas1:(boxes1的数量,),\n", - " # areas2:(boxes2的数量,)\n", - " areas1 = box_area(boxes1)\n", - " areas2 = box_area(boxes2)\n", - " # inter_upperlefts,inter_lowerrights,inters的形状:\n", - " # (boxes1的数量,boxes2的数量,2)\n", - " inter_upperlefts = ops.maximum(boxes1[:, None, :2], boxes2[:, :2])\n", - " inter_lowerrights = ops.minimum(boxes1[:, None, 2:], boxes2[:, 2:])\n", - " inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n", - " # inter_areasandunion_areas的形状:(boxes1的数量,boxes2的数量)\n", - " inter_areas = inters[:, :, 0] * inters[:, :, 1]\n", - " union_areas = areas1[:, None] + areas2 - inter_areas\n", - " return inter_areas / union_areas\n", - "\n", - "def assign_anchor_to_bbox(ground_truth, anchors, iou_threshold=0.5):\n", - " \"\"\"将最接近的真实边界框分配给锚框\"\"\"\n", - " num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n", - " # 位于第i行和第j列的元素x_ij是锚框i和真实边界框j的IoU\n", - " jaccard = box_iou(anchors, ground_truth)\n", - " # 对于每个锚框,分配的真实边界框的张量\n", - " anchors_bbox_map = ops.full((num_anchors,), -1, dtype=mindspore.int64)\n", - " # 根据阈值,决定是否分配真实边界框\n", - " indices, max_ious = ops.max(jaccard, axis=1)\n", - " anc_i = ops.nonzero(max_ious >= iou_threshold).reshape(-1)\n", - " box_j = ops.masked_select(indices, max_ious >= iou_threshold)\n", - " anchors_bbox_map[anc_i] = box_j \n", - " col_discard = ops.full((num_anchors,), -1)\n", - " row_discard = ops.full((num_gt_boxes,), -1)\n", - " for _ in range(num_gt_boxes):\n", - " max_idx = ops.argmax(jaccard)\n", - " box_idx = (max_idx % num_gt_boxes).long()\n", - " anc_idx = (max_idx / num_gt_boxes).long()\n", - " anchors_bbox_map[anc_idx] = box_idx\n", - " jaccard[:, box_idx] = col_discard\n", - " jaccard[anc_idx, :] = row_discard\n", - " return anchors_bbox_map\n", - "\n", - "def offset_boxes(anchors, assigned_bb, eps=1e-6):\n", - " \"\"\"对锚框偏移量的转换\"\"\"\n", - " c_anc = d2l.box_corner_to_center(anchors)\n", - " c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n", - " offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n", - " offset_wh = 5 * ops.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n", - " offset = ops.concat([offset_xy, offset_wh], axis=1)\n", - " return offset\n", - "\n", - "def multibox_target(anchors, labels):\n", - " \"\"\"使用真实边界框标记锚框\"\"\"\n", - " batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n", - " batch_offset, batch_mask, batch_class_labels = [], [], []\n", - " num_anchors = anchors.shape[0]\n", - " for i in range(batch_size):\n", - " label = labels[i, :, :]\n", - " anchors_bbox_map = assign_anchor_to_bbox(\n", - " label[:, 1:], anchors)\n", - " bbox_mask = ops.tile((anchors_bbox_map >= 0).float().unsqueeze(-1), (1, 4))\n", - " # 将类标签和分配的边界框坐标初始化为零\n", - " class_labels = ops.zeros(num_anchors, dtype=mindspore.int32)\n", - " assigned_bb = ops.zeros((num_anchors, 4), dtype=mindspore.float32)\n", - " # 使用真实边界框来标记锚框的类别。\n", - " # 如果一个锚框没有被分配,标记其为背景(值为零)\n", - " indices_true = ops.nonzero(anchors_bbox_map >= 0)\n", - " bb_idx = anchors_bbox_map[indices_true]\n", - " class_labels[indices_true] = label[bb_idx, 0].long() + 1 \n", - " assigned_bb[indices_true] = label[bb_idx, 1:]\n", - " # 偏移量转换\n", - " offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n", - " batch_offset.append(offset.reshape(-1))\n", - " batch_mask.append(bbox_mask.reshape(-1))\n", - " batch_class_labels.append(class_labels)\n", - " bbox_offset = ops.stack(batch_offset)\n", - " bbox_mask = ops.stack(batch_mask)\n", - " class_labels = ops.stack(batch_class_labels)\n", - " return (bbox_offset, bbox_mask, class_labels)\n", - "\n", - "# print(anchors.shape, Y.shape) # (1, 5444, 4) (32, 1, 5)\n", - "# bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)" + "#### 13.7.2.3. 训练模型" ] }, { "cell_type": "code", "execution_count": 19, - "id": "d80ea58d", + "id": "3870f282", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(32, 3, 256, 256) (32, 1, 5)\n" + "class err 2.85e-03, bbox mae 2.70e-03\n", + "6.8 examples/sec on \n" ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:09:36.597677\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "#forward\n", - "for X, Y in train_iter:\n", - " print(X.shape, Y.shape)\n", - " break" + "num_epochs, timer = 20, d2l.Timer()\n", + "animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", + " legend=['class error', 'bbox mae'])\n", + "\n", + "def forward_fn(X, Y):\n", + " # 生成多尺度的锚框,为每个锚框预测类别和偏移量\n", + " anchors, cls_preds, bbox_preds = net(X)\n", + " # 为每个锚框标注类别和偏移量\n", + " bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)\n", + " # 根据类别和偏移量的预测和标注值计算损失函数\n", + " l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", + " bbox_masks).mean()\n", + " return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", + " \n", + "grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True)\n", + " \n", + "def train_step(inputs, targets):\n", + " (l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels), grads = grad_fn(inputs, targets)\n", + " trainer(grads)\n", + " return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", + " \n", + "for epoch in range(num_epochs):\n", + " # 训练精确度的和,训练精确度的和中的示例数\n", + " # 绝对误差的和,绝对误差的和中的示例数\n", + " metric = d2l.Accumulator(4)\n", + " net.set_train()\n", + " for features, target in train_iter:\n", + " timer.start()\n", + " l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels = train_step(features, target)\n", + " metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),\n", + " bbox_eval(bbox_preds, bbox_labels, bbox_masks),\n", + " bbox_labels.numel())\n", + " cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]\n", + " animator.add(epoch + 1, (cls_err, bbox_mae))\n", + "print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')\n", + "print(f'{train_iter.get_dataset_size()/ timer.stop():.1f} examples/sec on ')" ] }, { - "cell_type": "code", - "execution_count": 20, - "id": "ca1c0d14", + "cell_type": "markdown", + "id": "425916e0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "output anchors: (1, 5444, 4)\n", - "output class preds: (32, 5444, 2)\n", - "output bbox preds: (32, 21776)\n" - ] - } - ], "source": [ - "anchors, cls_preds, bbox_preds = net(X)\n", - "print('output anchors:', anchors.shape)\n", - "print('output class preds:', cls_preds.shape)\n", - "print('output bbox preds:', bbox_preds.shape)" + "### 13.7.3. 预测目标" ] }, { "cell_type": "code", - "execution_count": 21, - "id": "d1f7601f", + "execution_count": 20, + "id": "a3b40a56", "metadata": { "scrolled": true }, + "outputs": [], + "source": [ + "img = mindspore.dataset.vision.read_image('../img/banana.jpg').astype('float32')\n", + "X = mindspore.Tensor(img.transpose(2,0,1), dtype=mindspore.float32).unsqueeze(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "ab17d783", + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.896 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n", - "[WARNING] KERNEL(269803,2aecc71afd80,python):2023-02-23-20:56:28.763.965 [mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h:51] CalShapesSizeInBytes] For 'Argmax', the shapes[0] is ( )\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bbox_labels: (32, 21776) (32, 21776)\n", - "bbox_masks: (32, 21776) Float32\n", - "cls_labels: (32, 5444) Int32\n" + "[WARNING] KERNEL(300375,2aec9fa87700,python):2023-04-03-17:11:46.720.731 [mindspore/ccsrc/kernel/kernel.h:514] CheckShapeNull] For 'Reshape', the shape of input cannot contain zero, but got (0, 1)\n", + "[WARNING] KERNEL(300375,2aec9fa87700,python):2023-04-03-17:11:46.723.968 [mindspore/ccsrc/kernel/kernel.h:514] CheckShapeNull] For 'Add', the shape of input_0 cannot contain zero, but got (0)\n", + "[WARNING] KERNEL(300375,2aec9fa87700,python):2023-04-03-17:11:46.734.618 [mindspore/ccsrc/kernel/kernel.h:514] CheckShapeNull] For 'Gather', the shape of indices cannot contain zero, but got (0)\n", + "<__array_function__ internals>:200: RuntimeWarning: invalid value encountered in cast\n" ] } ], "source": [ - "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", - "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", - "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - "print('cls_labels:', cls_labels.shape, cls_labels.dtype)" + "def predict(X):\n", + " net.set_train(False)\n", + " anchors, cls_preds, bbox_preds = net(X)\n", + " cls_probs = d2l.softmax(cls_preds, axis=2).permute(0, 2, 1)\n", + " output = d2l.multibox_detection(cls_probs, bbox_preds, anchors) # d2l.\n", + " idx = [i for i, row in enumerate(output[0]) if row[0] != -1]\n", + " return output[0][idx]\n", + "\n", + "output = predict(X)" ] }, { "cell_type": "code", - "execution_count": 22, - "id": "993302d8", + "execution_count": 26, + "id": "7c7d41f6", "metadata": {}, "outputs": [ { "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:18:28.618567\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], "text/plain": [ - "Tensor(shape=[], dtype=Float32, value= 0.69856)" + "
" ] }, - "execution_count": 22, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks).mean()\n", - "l" + "def display(img, output, threshold):\n", + " d2l.set_figsize((5, 5))\n", + " fig = d2l.plt.imshow(img)\n", + " for row in output:\n", + " score = float(row[1])\n", + " if score < threshold:\n", + " continue\n", + " h, w = img.shape[0:2]\n", + " bbox = [row[2:6] * d2l.tensor((w, h, w, h))]\n", + " d2l.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')\n", + "img = X.numpy().squeeze().transpose(1,2,0)/255\n", + "display(img, output, threshold=0.9)" ] }, { "cell_type": "markdown", - "id": "9c78354a", + "id": "a49128d0", "metadata": {}, "source": [ - "# BUG 直接运行的话,这里会卡住" + "### 13.7.5. 练习" ] }, { "cell_type": "code", - "execution_count": null, - "id": "5e0d0317", + "execution_count": 27, + "id": "77b6ea35", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:18:32.085257\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#forward\n", - "for X, Y in train_iter:\n", - " print(X.shape, Y.shape)\n", - " anchors, cls_preds, bbox_preds = net(X)\n", - " print('output anchors:', anchors.shape)\n", - " print('output class preds:', cls_preds.shape)\n", - " print('output bbox preds:', bbox_preds.shape)\n", + "import numpy as np\n", + "def smooth_l1(data, scalar):\n", + " out = []\n", + " for i in data:\n", + " if abs(i) < 1 / (scalar ** 2):\n", + " out.append(((scalar * i) ** 2) / 2)\n", + " else:\n", + " out.append(abs(i) - 0.5 / (scalar ** 2))\n", " \n", - " bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", - " print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", - " print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - " print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", - " \n", - " l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - " bbox_masks)\n", - " print(l)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6ef20694", - "metadata": {}, - "outputs": [], - "source": [ - "X, Y = next(train_iter)\n", - "print(X.shape, Y.shape)\n", - "anchors, cls_preds, bbox_preds = net(X)\n", - "print('output anchors:', anchors.shape)\n", - "print('output class preds:', cls_preds.shape)\n", - "print('output bbox preds:', bbox_preds.shape)\n", + " return np.array(out)\n", "\n", - "bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)\n", - "print('bbox_labels:', bbox_labels.shape, bbox_labels.shape)\n", - "print('bbox_masks:', bbox_masks.shape, bbox_masks.dtype)\n", - "print('cls_labels:', cls_labels.shape, cls_labels.dtype) \n", + "sigmas = [10, 1, 0.5]\n", + "lines = ['-', '--', '-.']\n", + "x = d2l.arange(-2, 2, 0.1).numpy()\n", + "d2l.set_figsize()\n", "\n", - "l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - " bbox_masks)\n", - "print(l)\n" - ] - }, - { - "cell_type": "markdown", - "id": "6912bb2d", - "metadata": {}, - "source": [ - "#### 13.7.2.3. 训练模型" + "for l, s in zip(lines, sigmas):\n", + " y = smooth_l1(x, scalar=s)\n", + " \n", + " d2l.plt.plot(x, y, l, label='sigma=%.1f' % s)\n", + "d2l.plt.legend();" ] }, { "cell_type": "code", - "execution_count": null, - "id": "3870f282", + "execution_count": 28, + "id": "289f5472", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-03T17:18:32.544692\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.6.3, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# num_epochs, timer = 20, d2l.Timer()\n", - "# animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", - "# legend=['class error', 'bbox mae'])\n", + "def focal_loss(gamma, x):\n", + " return -(1 - x) ** gamma * np.log(x)\n", "\n", - "# def forward_fn(X, Y):\n", - "# # 生成多尺度的锚框,为每个锚框预测类别和偏移量\n", - "# anchors, cls_preds, bbox_preds = net(X)\n", - "# # 为每个锚框标注类别和偏移量\n", - "# bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y) # d2l.\n", - "# # 根据类别和偏移量的预测和标注值计算损失函数\n", - "# l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,\n", - "# bbox_masks).mean()\n", - "# return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", - " \n", - "# grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True)\n", - " \n", - "# def train_step(inputs, targets):\n", - "# (l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels), grads = grad_fn(inputs, targets)\n", - "# trainer(grads)\n", - "# return l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels\n", - " \n", - "# for epoch in range(num_epochs):\n", - "# # 训练精确度的和,训练精确度的和中的示例数\n", - "# # 绝对误差的和,绝对误差的和中的示例数\n", - "# metric = d2l.Accumulator(4)\n", - "# net.set_train()\n", - "# for features, target in train_iter:\n", - "# timer.start()\n", - "# print(epoch)\n", - "# l, cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks, bbox_labels = train_step(features, target)\n", - "# metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),\n", - "# bbox_eval(bbox_preds, bbox_labels, bbox_masks),\n", - "# bbox_labels.numel())\n", - "# cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]\n", - "# animator.add(epoch + 1, (cls_err, bbox_mae))\n", - "# print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')\n", - "# print(f'{len(train_iter.dataset) / timer.stop():.1f} examples/sec on '\n", - "# f'{str(device)}')" + "x = np.arange(0.01, 1, 0.01)\n", + "for l, gamma in zip(lines, [0, 1, 5]):\n", + " y = d2l.plt.plot(x, focal_loss(gamma, x), l, label='gamma=%.1f' % gamma)\n", + "d2l.plt.legend();" ] } ], diff --git a/d2l/mindspore.py b/d2l/mindspore.py index 7b9523b..5a30d75 100644 --- a/d2l/mindspore.py +++ b/d2l/mindspore.py @@ -2485,50 +2485,6 @@ def resnet_block(in_channels, out_channels, num_residuals, return net -def train_batch_ch13(train_step, X, y): - """用GPU进行小批量训练""" - l, pred = train_step(X, y) - train_loss_sum = l.sum() - train_acc_sum = d2l.accuracy(pred, y) - return train_loss_sum, train_acc_sum - -def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs): - """用GPU进行模型训练""" - timer, num_batches = d2l.Timer(), train_iter.get_dataset_size() - animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], - legend=['train loss', 'train acc', 'test acc']) - - def forward_fn(inputs, targets): - logits = net(inputs) - l = loss(logits, targets) - return l, logits - - grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True) - - def train_step(inputs, targets): - (l, logits), grads = grad_fn(inputs, targets) - trainer(grads) - return l, logits - - for epoch in range(num_epochs): - # 4个维度:储存训练损失,训练准确度,实例数,特点数 - metric = d2l.Accumulator(4) - net.set_train() - for i, (features, labels) in enumerate(train_iter): - timer.start() - l, acc = train_batch_ch13( - train_step, features, labels) # , loss, trainer, devices - metric.add(l, acc, labels.shape[0], labels.numel()) - timer.stop() - if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1: - animator.add(epoch + (i + 1) / num_batches, - (metric[0] / metric[2], metric[1] / metric[3], - None)) - test_acc = d2l.evaluate_accuracy_gpu(net, test_iter) - animator.add(epoch + 1, (None, None, test_acc)) - print(f'loss {metric[0] / metric[2]:.3f}, train acc ' - f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}') - print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec') def multibox_prior(data, sizes, ratios): """生成以每个像素为中心具有不同形状的锚框""" @@ -2648,7 +2604,7 @@ def assign_anchor_to_bbox(ground_truth, anchors, iou_threshold=0.5): # 对于每个锚框,分配的真实边界框的张量 anchors_bbox_map = ops.full((num_anchors,), -1, dtype=mindspore.int64) # 根据阈值,决定是否分配真实边界框 - indices, max_ious = ops.max(jaccard, axis=1) + max_ious, indices = ops.max(jaccard, axis=1) anc_i = ops.nonzero(max_ious >= iou_threshold).reshape(-1) box_j = ops.masked_select(indices, max_ious >= iou_threshold) anchors_bbox_map[anc_i] = box_j @@ -2683,14 +2639,14 @@ def multibox_target(anchors, labels): label[:, 1:], anchors) bbox_mask = ops.tile((anchors_bbox_map >= 0).float().unsqueeze(-1), (1, 4)) # 将类标签和分配的边界框坐标初始化为零 - class_labels = ops.zeros(num_anchors, dtype=mindspore.int32) + class_labels = ops.zeros(num_anchors, dtype=mindspore.int64) assigned_bb = ops.zeros((num_anchors, 4), dtype=mindspore.float32) # 使用真实边界框来标记锚框的类别。 # 如果一个锚框没有被分配,标记其为背景(值为零) indices_true = ops.nonzero(anchors_bbox_map >= 0) bb_idx = anchors_bbox_map[indices_true] - class_labels[indices_true] = label[bb_idx, 0].long() + 1 - assigned_bb[indices_true] = label[bb_idx, 1:] + class_labels[indices_true] = label[:, 0][bb_idx].long() + 1 + assigned_bb[indices_true] = label[:, 1:][bb_idx] # 偏移量转换 offset = offset_boxes(anchors, assigned_bb) * bbox_mask batch_offset.append(offset.reshape(-1)) @@ -2698,7 +2654,7 @@ def multibox_target(anchors, labels): batch_class_labels.append(class_labels) bbox_offset = ops.stack(batch_offset) bbox_mask = ops.stack(batch_mask) - class_labels = ops.stack(batch_class_labels) + class_labels = ops.stack(batch_class_labels).astype('int32') return (bbox_offset, bbox_mask, class_labels) d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip', @@ -2746,6 +2702,70 @@ def load_data_bananas(batch_size): return train_iter, val_iter +def offset_inverse(anchors, offset_preds): + """根据带有预测偏移量的锚框来预测边界框""" + anc = d2l.box_corner_to_center(anchors) + pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2] + pred_bbox_wh = ops.exp(offset_preds[:, 2:] / 5) * anc[:, 2:] + pred_bbox = ops.concat((pred_bbox_xy, pred_bbox_wh), axis=1) + predicted_bbox = d2l.box_center_to_corner(pred_bbox) + return predicted_bbox + + +def nms(boxes, scores, iou_threshold): + """对预测边界框的置信度进行排序""" + B = ops.argsort(scores, axis=-1, descending=True) + keep = [] # 保留预测边界框的指标 + while B.numel() > 0: + i = B[0] + keep.append(i.asnumpy()) + if B.numel() == 1: break + iou = box_iou(boxes[i].reshape(-1, 4), + boxes[:, :][B[1:]].reshape(-1, 4)).reshape(-1) + inds = ops.nonzero(iou <= iou_threshold).reshape(-1) + B = B[inds + 1] + return mindspore.Tensor(keep) + + +def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5, + pos_threshold=0.009999999): + """使用非极大值抑制来预测边界框""" + batch_size = cls_probs.shape[0] + anchors = anchors.squeeze(0) + num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2] + out = [] + for i in range(batch_size): + cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4) + conf, class_id = ops.max(cls_prob[1:], 0) + predicted_bb = offset_inverse(anchors, offset_pred) + keep = nms(predicted_bb, conf, nms_threshold) + # 找到所有的non_keep索引,并将类设置为背景 + all_idx = ops.arange(num_anchors, dtype=mindspore.int64) + combined = ops.concat((keep, all_idx)) + unique = ops.unique(combined) + uniques, _ = unique[0], unique[1] + # 统计未重复的,mindspore的unique无法统计同一个元素出现次数 + counts = mnp.bincount(combined) + non_keep = mindspore.Tensor([int(uniques[i]) for i in range(len(counts)) + if (counts[i] == 1)]) + all_id_sorted = ops.concat((keep, non_keep)) + class_id[non_keep] = -1 + class_id = class_id[all_id_sorted] + conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted] + # pos_threshold是一个用于非背景预测的阈值 + below_min_idx = (conf < pos_threshold) + class_id = mindspore.Tensor.float(class_id) + class_id[below_min_idx] = -1 + for j in range(len(conf)): + if below_min_idx[j] == True: + conf[j] = 1 - conf[j] + + pred_info = ops.concat((class_id.unsqueeze(1), + conf.unsqueeze(1).astype('float32'), + predicted_bb), axis=1) + out.append(pred_info) + return ops.stack(out) + d2l.DATA_HUB['voc2012'] = (d2l.DATA_URL + 'VOCtrainval_11-May-2012.tar', '4e443f8a2eca6b1dac8a6c57641b67dd40621a49') @@ -3016,6 +3036,7 @@ def get_pooled_rois(self,feature_map, roi_batch): square = ops.square sqrt = ops.sqrt sign = ops.sign +softmax = ops.softmax meshgrid = ops.meshgrid linspace = ops.linspace zeros_like = ops.zeros_like