Add cpp folder for C++ frontend examples (pytorch#492)
* Create C++ version of MNIST example

* Create C++ version of DCGAN example

* Update for Normalize transform
goldsborough authored and soumith committed Jan 15, 2019
1 parent 5d27fdb commit 29a38c6
Showing 10 changed files with 678 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,3 +2,5 @@ dcgan/data
data
*.pyc
OpenNMT/data
cpp/mnist/build
cpp/dcgan/build
88 changes: 88 additions & 0 deletions cpp/.clang-format
@@ -0,0 +1,88 @@
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
  - Regex: '^<.*\.h(pp)?>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
20 changes: 20 additions & 0 deletions cpp/dcgan/CMakeLists.txt
@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)

find_package(Torch REQUIRED)

option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
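# Note: pass -DDOWNLOAD_MNIST=OFF on the cmake command line to skip the download
# step if the dataset is already present under <build>/data.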
if (DOWNLOAD_MNIST)
  message(STATUS "Downloading MNIST dataset")
  execute_process(
    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
      -d ${CMAKE_BINARY_DIR}/data
    ERROR_VARIABLE DOWNLOAD_ERROR)
  if (DOWNLOAD_ERROR)
    message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
  endif()
endif()

add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
set_property(TARGET dcgan PROPERTY CXX_STANDARD 11)
56 changes: 56 additions & 0 deletions cpp/dcgan/README.md
@@ -0,0 +1,56 @@
# DCGAN Example with the PyTorch C++ Frontend

This folder contains an example of training a DCGAN to generate MNIST digits
with the PyTorch C++ frontend.

The entire training code is contained in `dcgan.cpp`.

To build the code, run the following commands from your terminal:

```shell
$ cd dcgan
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
distribution, which you can get from the [PyTorch
homepage](https://pytorch.org/get-started/locally/).

Execute the compiled binary to train the model:

```shell
$ ./dcgan
[ 1/30][200/938] D_loss: 0.4953 | G_loss: 4.0195
-> checkpoint 1
[ 1/30][400/938] D_loss: 0.3610 | G_loss: 4.8148
-> checkpoint 2
[ 1/30][600/938] D_loss: 0.4072 | G_loss: 4.3676
-> checkpoint 3
[ 1/30][800/938] D_loss: 0.4444 | G_loss: 4.0250
-> checkpoint 4
[ 2/30][200/938] D_loss: 0.3761 | G_loss: 3.8790
-> checkpoint 5
[ 2/30][400/938] D_loss: 0.3977 | G_loss: 3.3315
-> checkpoint 6
[ 2/30][600/938] D_loss: 0.3815 | G_loss: 3.5696
-> checkpoint 7
[ 2/30][800/938] D_loss: 0.4039 | G_loss: 3.2759
-> checkpoint 8
[ 3/30][200/938] D_loss: 0.4236 | G_loss: 4.5132
-> checkpoint 9
[ 3/30][400/938] D_loss: 0.3645 | G_loss: 3.9759
-> checkpoint 10
...
```

The training script periodically generates image samples. Use the
`display_samples.py` script situated in this folder to generate a plot image.
For example:

```shell
$ python display_samples.py -i dcgan-sample-10.pt
Saved out.png
```
187 changes: 187 additions & 0 deletions cpp/dcgan/dcgan.cpp
@@ -0,0 +1,187 @@
#include <torch/torch.h>

#include <cmath>
#include <cstdio>
#include <iostream>

// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;

// The batch size for training.
const int64_t kBatchSize = 64;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;

// Where to find the MNIST dataset.
const char* kDataFolder = "./data";

// After how many batches to create a new checkpoint.
const int64_t kCheckpointEvery = 200;

// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;

// Set to `true` to restore models and optimizers from previously saved
// checkpoints.
const bool kRestoreFromCheckpoint = false;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

using namespace torch;

int main(int argc, const char* argv[]) {
torch::manual_seed(1);

// Create the device we pass around based on whether CUDA is available.
torch::Device device(torch::kCPU);
if (torch::cuda::is_available()) {
std::cout << "CUDA is available! Training on GPU." << std::endl;
device = torch::Device(torch::kCUDA);
}

nn::Sequential generator(
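// Maps a kNoiseSize x 1 x 1 noise tensor up to a 1 x 28 x 28 image through a
// stack of transposed convolutions; the final tanh keeps outputs in [-1, 1].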
// Layer 1
nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(256),
nn::Functional(torch::relu),
// Layer 2
nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(128),
nn::Functional(torch::relu),
// Layer 3
nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(64),
nn::Functional(torch::relu),
// Layer 4
nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::Functional(torch::tanh));
generator->to(device);

nn::Sequential discriminator(
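// Mirrors the generator: strided convolutions shrink the 1 x 28 x 28 input
// down to a single sigmoid score for how "real" the image looks.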
// Layer 1
nn::Conv2d(
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
nn::Functional(torch::leaky_relu, 0.2),
// Layer 2
nn::Conv2d(
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
nn::BatchNorm(128),
nn::Functional(torch::leaky_relu, 0.2),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
nn::BatchNorm(256),
nn::Functional(torch::leaky_relu, 0.2),
// Layer 4
nn::Conv2d(
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
nn::Functional(torch::sigmoid));
discriminator->to(device);

// Assume the MNIST dataset is available under `kDataFolder`.
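// Normalize(0.5, 0.5) rescales the MNIST pixels from [0, 1] to [-1, 1],
// matching the range of the generator's tanh output.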
auto dataset = torch::data::datasets::MNIST(kDataFolder)
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
.map(torch::data::transforms::Stack<>());
const int64_t batches_per_epoch =
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

auto data_loader = torch::data::make_data_loader(
std::move(dataset),
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

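// Both networks use Adam with lr = 2e-4 and beta1 = 0.5, the hyperparameters
// recommended in the DCGAN paper.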
torch::optim::Adam generator_optimizer(
generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));

if (kRestoreFromCheckpoint) {
torch::load(generator, "generator-checkpoint.pt");
torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::load(discriminator, "discriminator-checkpoint.pt");
torch::load(
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
}

int64_t checkpoint_counter = 1;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
int64_t batch_index = 0;
for (torch::data::Example<>& batch : *data_loader) {
// Train discriminator with real images.
discriminator->zero_grad();
torch::Tensor real_images = batch.data.to(device);
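// Use soft targets in [0.8, 1.0] for real images (one-sided label smoothing).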
torch::Tensor real_labels =
torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
torch::Tensor real_output = discriminator->forward(real_images);
torch::Tensor d_loss_real =
torch::binary_cross_entropy(real_output, real_labels);
d_loss_real.backward();

// Train discriminator with fake images.
torch::Tensor noise =
torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
torch::Tensor fake_images = generator->forward(noise);
torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
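// Detach the fake images so this backward pass only updates the discriminator.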
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
torch::Tensor d_loss_fake =
torch::binary_cross_entropy(fake_output, fake_labels);
d_loss_fake.backward();

torch::Tensor d_loss = d_loss_real + d_loss_fake;
discriminator_optimizer.step();

// Train generator.
generator->zero_grad();
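// Reuse the same fake batch, but now labeled as real (1): the generator's
// loss is smallest when the discriminator is fooled.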
fake_labels.fill_(1);
fake_output = discriminator->forward(fake_images);
torch::Tensor g_loss =
torch::binary_cross_entropy(fake_output, fake_labels);
g_loss.backward();
generator_optimizer.step();

batch_index++;
if (batch_index % kLogInterval == 0) {
  std::printf(
      "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
      epoch,
      kNumberOfEpochs,
      batch_index,
      batches_per_epoch,
      d_loss.item<float>(),
      g_loss.item<float>());
}

if (batch_index % kCheckpointEvery == 0) {
// Checkpoint the model and optimizer state.
torch::save(generator, "generator-checkpoint.pt");
torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::save(discriminator, "discriminator-checkpoint.pt");
torch::save(
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
// Sample the generator and save the images.
torch::Tensor samples = generator->forward(torch::randn(
{kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
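// Shift the tanh output from [-1, 1] back to [0, 1] before saving.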
torch::save(
(samples + 1.0) / 2.0,
torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
}
}

std::cout << "Training complete!" << std::endl;
}
28 changes: 28 additions & 0 deletions cpp/dcgan/display_samples.py
@@ -0,0 +1,28 @@
from __future__ import print_function
from __future__ import unicode_literals

import argparse

import matplotlib.pyplot as plt
import torch


parser = argparse.ArgumentParser()
parser.add_argument("-i", "--sample-file", required=True)
parser.add_argument("-o", "--out-file", default="out.png")
parser.add_argument("-d", "--dimension", type=int, default=3)
options = parser.parse_args()

module = torch.jit.load(options.sample_file)
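# torch::save() on the C++ side writes a TorchScript archive, so the sample
# tensor comes back as the loaded module's first parameter.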
images = list(module.parameters())[0]

for index in range(options.dimension * options.dimension):
    image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
    array = image.numpy()
    axis = plt.subplot(options.dimension, options.dimension, 1 + index)
    plt.imshow(array, cmap="gray")
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)

plt.savefig(options.out_file)
print("Saved", options.out_file)
