diff --git a/include/tkDNN/DarknetParser.h b/include/tkDNN/DarknetParser.h
index 089c4d69..b90e3f4c 100644
--- a/include/tkDNN/DarknetParser.h
+++ b/include/tkDNN/DarknetParser.h
@@ -28,6 +28,7 @@ namespace tk { namespace dnn {
         int new_coords= 0;
         float scale_xy = 1;
         float nms_thresh = 0.45;
+        int scale_wh_in_scale_channels = 0;
         std::vector<int> layers;
         std::string activation = "linear";
diff --git a/include/tkDNN/Layer.h b/include/tkDNN/Layer.h
index 25c45652..f98d36c5 100644
--- a/include/tkDNN/Layer.h
+++ b/include/tkDNN/Layer.h
@@ -19,6 +19,7 @@ enum layerType_t {
     LAYER_ACTIVATION_CRELU,
     LAYER_ACTIVATION_LEAKY,
     LAYER_ACTIVATION_MISH,
+    LAYER_ACTIVATION_SWISH,
     LAYER_FLATTEN,
     LAYER_RESHAPE,
     LAYER_MULADD,
@@ -27,6 +28,7 @@ enum layerType_t {
     LAYER_ROUTE,
     LAYER_REORG,
     LAYER_SHORTCUT,
+    LAYER_SCALECHANNELS,
     LAYER_UPSAMPLE,
     LAYER_REGION,
     LAYER_YOLO
@@ -68,6 +70,7 @@ class Layer {
             case LAYER_ACTIVATION_CRELU: return "ActivationCReLU";
             case LAYER_ACTIVATION_LEAKY: return "ActivationLeaky";
             case LAYER_ACTIVATION_MISH: return "ActivationMish";
+            case LAYER_ACTIVATION_SWISH: return "ActivationSwish";
             case LAYER_FLATTEN: return "Flatten";
             case LAYER_RESHAPE: return "Reshape";
             case LAYER_MULADD: return "MulAdd";
@@ -76,6 +79,7 @@ class Layer {
             case LAYER_ROUTE: return "Route";
             case LAYER_REORG: return "Reorg";
             case LAYER_SHORTCUT: return "Shortcut";
+            case LAYER_SCALECHANNELS: return "ScaleChannels";
             case LAYER_UPSAMPLE: return "Upsample";
             case LAYER_REGION: return "Region";
             case LAYER_YOLO: return "Yolo";
@@ -212,7 +216,8 @@ class Dense : public LayerWgs {
 typedef enum {
     ACTIVATION_ELU = 100,
     ACTIVATION_LEAKY = 101,
-    ACTIVATION_MISH = 102
+    ACTIVATION_MISH = 102,
+    ACTIVATION_SWISH = 103
 } tkdnnActivationMode_t;
 
 /**
@@ -233,6 +238,8 @@ class Activation : public Layer {
             return LAYER_ACTIVATION_LEAKY;
         else if (act_mode == ACTIVATION_MISH)
             return LAYER_ACTIVATION_MISH;
+        else if (act_mode == ACTIVATION_SWISH)
+            return LAYER_ACTIVATION_SWISH;
         else
             return LAYER_ACTIVATION;
     };
@@ -557,6 +564,25 @@ class Shortcut : public Layer {
     Layer *backLayer;
 };
 
+/**
+    ScaleChannels layer
+    channel-wise multiplication with another layer's output
+*/
+class ScaleChannels : public Layer {
+
+public:
+    ScaleChannels(Network *net, Layer *backLayer, int scale_wh);
+    virtual ~ScaleChannels();
+    virtual layerType_t getLayerType() { return LAYER_SCALECHANNELS; };
+
+    virtual dnnType* infer(dataDim_t &dim, dnnType* srcData);
+
+public:
+    Layer *backLayer;
+    int scale_wh;
+};
+
+
 /**
     Upsample layer
     Maintains same dimension but change C*H*W distribution
diff --git a/include/tkDNN/NetworkRT.h b/include/tkDNN/NetworkRT.h
index 4c6c8162..ecb5305c 100644
--- a/include/tkDNN/NetworkRT.h
+++ b/include/tkDNN/NetworkRT.h
@@ -26,6 +26,7 @@ using namespace nvinfer1;
 #include "pluginsRT/ActivationLeakyRT.h"
 #include "pluginsRT/ActivationReLUCeilingRT.h"
 #include "pluginsRT/ActivationMishRT.h"
+#include "pluginsRT/ActivationSwishRT.h"
 #include "pluginsRT/ReorgRT.h"
 #include "pluginsRT/RegionRT.h"
 #include "pluginsRT/RouteRT.h"
@@ -108,6 +109,7 @@ class NetworkRT {
     nvinfer1::ILayer* convert_layer(nvinfer1::ITensor *input, Reorg *l);
     nvinfer1::ILayer* convert_layer(nvinfer1::ITensor *input, Region *l);
     nvinfer1::ILayer* convert_layer(nvinfer1::ITensor *input, Shortcut *l);
+    nvinfer1::ILayer* convert_layer(nvinfer1::ITensor *input, ScaleChannels *l);
     nvinfer1::ILayer* convert_layer(nvinfer1::ITensor *input, Yolo *l);
     nvinfer1::ILayer* convert_layer(nvinfer1::ITensor *input, Upsample *l);
     nvinfer1::ILayer* convert_layer(nvinfer1::ITensor *input, DeformConv2d *l);
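Reviewer note (not part of the patch): the ScaleChannels layer declared above is the squeeze-and-excitation multiply used by the EfficientNet-style blocks in the test cfg. With scale_wh=0 — the only mode the parser below accepts — each channel of the back layer's N×C×H×W output is multiplied by the matching entry of the N×C×1×1 excitation tensor. A minimal CPU sketch of that contract, assuming row-major NCHW layout:

#include <vector>
#include <cstdio>

// Reference semantics of ScaleChannels with scale_wh=0:
// "in" is the back layer's N*C*H*W output, "scales" is the
// N*C*1*1 excitation tensor produced by the logistic conv.
std::vector<float> scaleChannelsRef(const std::vector<float>& in,
                                    const std::vector<float>& scales,
                                    int n, int c, int h, int w) {
    std::vector<float> out(in.size());
    int channel_size = h * w;
    for (size_t i = 0; i < in.size(); i++)
        out[i] = in[i] * scales[i / channel_size];  // i / channel_size == n*C + c
    return out;
}

int main() {
    // one batch, two channels, 2x2 maps, scaled by 0.5 and 2.0
    std::vector<float> in = {1,1,1,1, 1,1,1,1};
    std::vector<float> scales = {0.5f, 2.0f};
    auto out = scaleChannelsRef(in, scales, 1, 2, 2, 2);
    std::printf("%.1f %.1f\n", out[0], out[4]);  // expected: 0.5 2.0
    return 0;
}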
diff --git a/include/tkDNN/kernels.h b/include/tkDNN/kernels.h
index 5d673c87..19f74bbb 100644
--- a/include/tkDNN/kernels.h
+++ b/include/tkDNN/kernels.h
@@ -9,6 +9,7 @@ void activationReLUCeilingForward(dnnType *srcData, dnnType *dstData, int size,
 void activationLOGISTICForward(dnnType *srcData, dnnType *dstData, int size, cudaStream_t stream = cudaStream_t(0));
 void activationSIGMOIDForward(dnnType *srcData, dnnType *dstData, int size, cudaStream_t stream = cudaStream_t(0));
 void activationMishForward(dnnType* srcData, dnnType* dstData, int size, cudaStream_t stream= cudaStream_t(0));
+void activationSwishForward(dnnType* srcData, dnnType* dstData, int size, cudaStream_t stream= cudaStream_t(0));
 
 void fill(dnnType *data, int size, dnnType val, cudaStream_t stream = cudaStream_t(0));
@@ -27,6 +28,10 @@ void shortcutForward(dnnType *srcData, dnnType *dstData,
                      int n1, int c1, int h1, int w1, int s1,
                      int n2, int c2, int h2, int w2, int s2,
                      cudaStream_t stream = cudaStream_t(0));
 
+void scaleChannelsForward(dnnType *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh,
+                          dnnType *scales_c, dnnType *out,
+                          cudaStream_t stream = cudaStream_t(0));
+
 void upsampleForward(dnnType *srcData, dnnType *dstData,
                      int n, int c, int h, int w, int s, int forward, float scale,
                      cudaStream_t stream = cudaStream_t(0));
diff --git a/include/tkDNN/pluginsRT/ActivationSwishRT.h b/include/tkDNN/pluginsRT/ActivationSwishRT.h
new file mode 100644
index 00000000..9796e1f0
--- /dev/null
+++ b/include/tkDNN/pluginsRT/ActivationSwishRT.h
@@ -0,0 +1,60 @@
+#include <cassert>
+#include "../kernels.h"
+
+class ActivationSwishRT : public IPlugin {
+
+public:
+    ActivationSwishRT() {
+
+    }
+
+    ~ActivationSwishRT(){
+
+    }
+
+    int getNbOutputs() const override {
+        return 1;
+    }
+
+    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override {
+        return inputs[0];
+    }
+
+    void configure(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) override {
+        size = 1;
+        for(int i=0; i<outputDims[0].nbDims; i++)
+            size *= outputDims[0].d[i];
+    }
+
+    int initialize() override {
+        return 0;
+    }
+
+    virtual void terminate() override {
+    }
+
+    virtual size_t getWorkspaceSize(int maxBatchSize) const override {
+        return 0;
+    }
+
+    virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override {
+        activationSwishForward((dnnType*)reinterpret_cast<const dnnType*>(inputs[0]),
+                    reinterpret_cast<dnnType*>(outputs[0]), batchSize*size, stream);
+        return 0;
+    }
+
+
+    virtual size_t getSerializationSize() override {
+        return 1*sizeof(int);
+    }
+
+    virtual void serialize(void* buffer) override {
+        char *buf = reinterpret_cast<char*>(buffer);
+        tk::dnn::writeBUF(buf, size);
+    }
+
+    int size;
+};
diff --git a/src/Activation.cpp b/src/Activation.cpp
index 28c76242..543e29ab 100644
--- a/src/Activation.cpp
+++ b/src/Activation.cpp
@@ -52,6 +52,9 @@ dnnType* Activation::infer(dataDim_t &dim, dnnType* srcData) {
     else if(act_mode == ACTIVATION_MISH) {
         activationMishForward(srcData, dstData, dim.tot());
+    }
+    else if(act_mode == ACTIVATION_SWISH) {
+        activationSwishForward(srcData, dstData, dim.tot());
     } else {
         dnnType alpha = dnnType(1);
         dnnType beta = dnnType(0);
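Reviewer note: a custom plugin keeps Swish parity with how Mish is handled in this codebase, but for reference the same function could be composed from stock TensorRT layers (a sigmoid activation plus an element-wise product), trading the custom kernel for one extra intermediate tensor. An illustrative sketch only, not part of the patch (the addSwish helper name is hypothetical):

#include <NvInfer.h>

// Alternative to ActivationSwishRT: build swish(x) = x * sigmoid(x)
// out of built-in TensorRT layers instead of a plugin.
nvinfer1::ILayer* addSwish(nvinfer1::INetworkDefinition* net, nvinfer1::ITensor* input) {
    nvinfer1::IActivationLayer* sig =
        net->addActivation(*input, nvinfer1::ActivationType::kSIGMOID);
    return net->addElementWise(*input, *sig->getOutput(0),
                               nvinfer1::ElementWiseOperation::kPROD);
}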
diff --git a/src/DarknetParser.cpp b/src/DarknetParser.cpp
index 7b5410cb..3fda8f02 100644
--- a/src/DarknetParser.cpp
+++ b/src/DarknetParser.cpp
@@ -80,6 +80,8 @@ namespace tk { namespace dnn {
                 fields.groups = std::stoi(value);
             else if(name.find("group_id") != std::string::npos)
                 fields.group_id = std::stoi(value);
+            else if(name.find("scale_wh") != std::string::npos)
+                fields.scale_wh_in_scale_channels = std::stoi(value);
             else if(name.find("scale_x_y") != std::string::npos)
                 fields.scale_xy = std::stof(value);
             else if(name.find("beta_nms") != std::string::npos)
@@ -134,7 +136,11 @@ namespace tk { namespace dnn {
                                 f.padding_x, f.padding_y, tk::dnn::POOLING_MAX));
         } else if(f.type == "avgpool") {
-            netLayers.push_back(new tk::dnn::Pooling(net, f.size_x, f.size_y, f.stride_x, f.stride_y,
+            auto output_dim = net->getOutputDim();
+            int stride = 1;
+            assert(f.padding_x == 0 && f.padding_y == 0);
+
+            netLayers.push_back(new tk::dnn::Pooling(net, output_dim.h, output_dim.w, stride, stride,
                                 f.padding_x, f.padding_y, tk::dnn::POOLING_AVERAGE));
         } else if(f.type == "shortcut") {
@@ -146,6 +152,18 @@ namespace tk { namespace dnn {
             //std::cout<<"shortcut to "<<layerIdx<<" "<<netLayers[layerIdx]->getLayerName()<<"\n";
             netLayers.push_back(new tk::dnn::Shortcut(net, netLayers[layerIdx]));
 
+        } else if(f.type == "scale_channels") {
+            if(f.layers.size() != 1) FatalError("no layers to scale_channels\n");
+            int layerIdx = f.layers[0];
+            if(layerIdx < 0)
+                layerIdx = netLayers.size() + layerIdx;
+            if(layerIdx < 0 || layerIdx >= netLayers.size()) FatalError("impossible to scale_channels\n");
+
+            int scale_wh = f.scale_wh_in_scale_channels;
+            if(scale_wh != 0) FatalError("Currently only scale_wh=0 is supported in scale_channels\n");
+
+            netLayers.push_back(new tk::dnn::ScaleChannels(net, netLayers[layerIdx], scale_wh));
+
         } else if(f.type == "upsample") {
             netLayers.push_back(new tk::dnn::Upsample(net, f.stride_x));
@@ -185,8 +203,10 @@ namespace tk { namespace dnn {
         if(netLayers.size() > 0 && f.activation != "linear") {
             tkdnnActivationMode_t act;
             if(f.activation == "relu") act = tkdnnActivationMode_t(CUDNN_ACTIVATION_RELU);
+            else if(f.activation == "logistic") act = tkdnnActivationMode_t(CUDNN_ACTIVATION_SIGMOID);
             else if(f.activation == "leaky") act = tk::dnn::ACTIVATION_LEAKY;
             else if(f.activation == "mish") act = tk::dnn::ACTIVATION_MISH;
+            else if(f.activation == "swish") act = tk::dnn::ACTIVATION_SWISH;
             else { FatalError("activation not supported: " + f.activation); }
             netLayers[netLayers.size()-1] = new tk::dnn::Activation(net, act);
         };
diff --git a/src/NetworkRT.cpp b/src/NetworkRT.cpp
index 501ade4f..429bf463 100644
--- a/src/NetworkRT.cpp
+++ b/src/NetworkRT.cpp
@@ -226,7 +226,7 @@ ILayer* NetworkRT::convert_layer(ITensor *input, Layer *l) {
         return convert_layer(input, (Conv2d*) l);
     if(type == LAYER_POOLING)
         return convert_layer(input, (Pooling*) l);
-    if(type == LAYER_ACTIVATION || type == LAYER_ACTIVATION_CRELU || type == LAYER_ACTIVATION_LEAKY || type == LAYER_ACTIVATION_MISH)
+    if(type == LAYER_ACTIVATION || type == LAYER_ACTIVATION_CRELU || type == LAYER_ACTIVATION_LEAKY || type == LAYER_ACTIVATION_MISH || type == LAYER_ACTIVATION_SWISH)
         return convert_layer(input, (Activation*) l);
     if(type == LAYER_SOFTMAX)
         return convert_layer(input, (Softmax*) l);
@@ -242,6 +242,8 @@ ILayer* NetworkRT::convert_layer(ITensor *input, Layer *l) {
         return convert_layer(input, (Region*) l);
     if(type == LAYER_SHORTCUT)
         return convert_layer(input, (Shortcut*) l);
+    if(type == LAYER_SCALECHANNELS)
+        return convert_layer(input, (ScaleChannels*) l);
     if(type == LAYER_YOLO)
         return convert_layer(input, (Yolo*) l);
     if(type == LAYER_UPSAMPLE)
@@ -421,6 +423,12 @@ ILayer* NetworkRT::convert_layer(ITensor *input, Activation *l) {
         checkNULL(lRT);
         return lRT;
     }
+    else if(l->act_mode == ACTIVATION_SWISH) {
+        IPlugin *plugin = new ActivationSwishRT();
+        IPluginLayer *lRT = networkRT->addPlugin(&input, 1, *plugin);
+        checkNULL(lRT);
+        return lRT;
+    }
     else {
         FatalError("this Activation mode is not yet implemented");
         return NULL;
@@ -525,6 +533,14 @@ ILayer* NetworkRT::convert_layer(ITensor *input, Shortcut *l) {
     }
 }
 
+ILayer* NetworkRT::convert_layer(ITensor *input, ScaleChannels *l) {
+    ITensor *back_tens = tensors[l->backLayer];
+
+    IElementWiseLayer *lRT = networkRT->addElementWise(*input, *back_tens, ElementWiseOperation::kPROD);
+    checkNULL(lRT);
+    return lRT;
+}
+
 ILayer* NetworkRT::convert_layer(ITensor *input, Yolo *l) {
     //std::cout<<"convert Yolo\n";
@@ -653,6 +669,11 @@ IPlugin* PluginFactory::createPlugin(const char* layerName, const void* serialData, size_t serialLength) {
         a->size = readBUF<int>(buf);
         return a;
     }
+    if(name.find("ActivationSwish") == 0) {
+        ActivationSwishRT *a = new ActivationSwishRT();
+        a->size = readBUF<int>(buf);
+        return a;
+    }
     if(name.find("ActivationCReLU") == 0) {
         ActivationReLUCeiling *a = new ActivationReLUCeiling(readBUF<float>(buf));
         a->size = readBUF<int>(buf);
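Reviewer note: the avgpool change above turns darknet's [avgpool] into a pooling window spanning the whole feature map, i.e. global average pooling, which is what the squeeze step of an SE block expects (it reduces N×C×H×W to N×C×1×1). Reference semantics as a plain C++ sketch, illustrative only:

#include <vector>

// Global average pooling: collapse each HxW map to one value,
// matching Pooling(net, H, W, 1, 1, 0, 0, POOLING_AVERAGE) above.
std::vector<float> globalAvgPoolRef(const std::vector<float>& in, int n, int c, int h, int w) {
    std::vector<float> out(n * c, 0.f);
    for (int i = 0; i < n * c; i++) {
        for (int j = 0; j < h * w; j++)
            out[i] += in[i * h * w + j];
        out[i] /= float(h * w);
    }
    return out;
}

int main() {
    std::vector<float> in = {1,2,3,4, 10,10,10,10};     // 1x2x2x2
    auto out = globalAvgPoolRef(in, 1, 2, 2, 2);
    return (out[0] == 2.5f && out[1] == 10.f) ? 0 : 1;  // expected means: 2.5, 10.0
}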
diff --git a/src/ScaleChannels.cpp b/src/ScaleChannels.cpp
new file mode 100644
index 00000000..fe45e213
--- /dev/null
+++ b/src/ScaleChannels.cpp
@@ -0,0 +1,37 @@
+#include <iostream>
+
+#include "Layer.h"
+#include "kernels.h"
+
+namespace tk { namespace dnn {
+ScaleChannels::ScaleChannels(Network *net, Layer *backLayer, int scale_wh) : Layer(net) {
+
+    this->backLayer = backLayer;
+    this->scale_wh = scale_wh;
+    output_dim = backLayer->output_dim;
+    checkCuda( cudaMalloc(&dstData, output_dim.tot()*sizeof(dnnType)) );
+
+    if( backLayer->output_dim.c != input_dim.c )
+        FatalError("ScaleChannels dim mismatch");
+
+}
+
+ScaleChannels::~ScaleChannels() {
+
+    checkCuda( cudaFree(dstData) );
+}
+
+dnnType* ScaleChannels::infer(dataDim_t &dim, dnnType* srcData) {
+
+    int size = output_dim.n * output_dim.c * output_dim.h * output_dim.w;
+    int channel_size = output_dim.h * output_dim.w;
+    int batch_size = output_dim.c * output_dim.h * output_dim.w;
+    scaleChannelsForward(this->backLayer->dstData, size, channel_size, batch_size, scale_wh, srcData, dstData);
+
+    //update data dimensions
+    dim = output_dim;
+
+    return dstData;
+}
+
+}}
\ No newline at end of file
diff --git a/src/kernels/activation_swish.cu b/src/kernels/activation_swish.cu
new file mode 100644
index 00000000..2c7e2e9e
--- /dev/null
+++ b/src/kernels/activation_swish.cu
@@ -0,0 +1,27 @@
+#include "kernels.h"
+#include <math.h>
+
+// https://github.com/AlexeyAB/darknet/blob/master/src/activation_kernels.cu
+__device__ float logistic_activate_kernel(float x){return 1.f/(1.f + expf(-x));}
+
+__global__
+void activation_swish(dnnType *input, dnnType *output, int size) {
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < size)
+    {
+        float x_val = input[i];
+        float sigmoid = logistic_activate_kernel(x_val);
+        output[i] = x_val * sigmoid;
+    }
+}
+
+/**
+    swish activation function
+*/
+void activationSwishForward(dnnType* srcData, dnnType* dstData, int size, cudaStream_t stream)
+{
+    int blocks = (size+255)/256;
+    int threads = 256;
+
+    activation_swish<<<blocks, threads, 0, stream>>>(srcData, dstData, size);
+}
\ No newline at end of file
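Reviewer note: swish(x) = x * sigmoid(x) = x / (1 + e^-x). A few host-side reference values to sanity-check the kernel above against (illustrative, not part of the patch):

#include <cmath>
#include <cstdio>

// Host-side swish, mirroring activation_swish above.
float swishRef(float x) { return x / (1.f + std::exp(-x)); }

int main() {
    // swish(0) = 0, swish(1) ~ 0.7311, swish(-1) ~ -0.2689,
    // and swish(x) approaches x for large positive x.
    for (float x : {0.f, 1.f, -1.f, 6.f})
        std::printf("swish(%+.1f) = %+.4f\n", x, swishRef(x));
    return 0;
}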
diff --git a/src/kernels/scale_channels.cu b/src/kernels/scale_channels.cu
new file mode 100644
index 00000000..d00172e3
--- /dev/null
+++ b/src/kernels/scale_channels.cu
@@ -0,0 +1,27 @@
+#include "kernels.h"
+#include "assert.h"
+
+// https://github.com/AlexeyAB/darknet/blob/master/src/blas_kernels.cu
+__global__ void scale_channels_kernel(float *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh, float *scales_c, float *out)
+{
+    const int index = blockIdx.x*blockDim.x + threadIdx.x;
+    if (index < size) {
+        if (scale_wh) {
+            int osd_index = index % channel_size + (index / batch_size)*channel_size;
+
+            out[index] = in_w_h_c[index] * scales_c[osd_index];
+        }
+        else {
+            out[index] = in_w_h_c[index] * scales_c[index / channel_size];
+        }
+    }
+}
+
+void scaleChannelsForward(dnnType *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh,
+                          dnnType *scales_c, dnnType *out, cudaStream_t stream)
+{
+    int blocks = (size+255)/256;
+    int threads = 256;
+
+    scale_channels_kernel <<<blocks, threads, 0, stream>>> (in_w_h_c, size, channel_size, batch_size, scale_wh, scales_c, out);
+}
diff --git a/tests/darknet/cfg/enet-coco-wo-dropout.cfg b/tests/darknet/cfg/enet-coco-wo-dropout.cfg
new file mode 100644
index 00000000..8824fd94
--- /dev/null
+++ b/tests/darknet/cfg/enet-coco-wo-dropout.cfg
@@ -0,0 +1,1072 @@
+[net]
+# Testing
+#batch=1
+#subdivisions=1
+# Training
+batch=64
+subdivisions=8
+width=416
+height=416
+channels=3
+momentum=0.9
+decay=0.0005
+angle=0
+saturation = 1.5
+exposure = 1.5
+hue=.1
+
+learning_rate=0.001
+burn_in=1000
+max_batches = 500200
+policy=steps
+steps=400000,450000
+scales=.1,.1
+
+### CONV1 - 1 (1)
+# conv1
+[convolutional]
+filters=32
+size=3
+pad=1
+stride=2
+batch_normalize=1
+activation=swish
+
+
+### CONV2 - MBConv1 - 1 (1)
+# conv2_1_expand
+[convolutional]
+filters=32
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv2_1_dwise
+[convolutional]
+groups=32
+filters=32
+size=3
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=4 (recommended r=16)
+[convolutional]
+filters=8
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=32
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv2_1_linear
+[convolutional]
+filters=16
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+
+### CONV3 - MBConv6 - 1 (2)
+# conv2_2_expand
+[convolutional]
+filters=96
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv2_2_dwise
+[convolutional]
+groups=96
+filters=96
+size=3
+pad=1
+stride=2
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=8 (recommended r=16)
+[convolutional]
+filters=16
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=96
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv2_2_linear
+[convolutional]
+filters=24
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV3 - MBConv6 - 2 (2)
+# conv3_1_expand
+[convolutional]
+filters=144
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv3_1_dwise
+[convolutional]
+groups=144
+filters=144
+size=3
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=8
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=144
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv3_1_linear
+[convolutional]
+filters=24
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+
+### CONV4 - MBConv6 - 1 (2)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_3_1
+[shortcut]
+from=-8
+activation=linear
+
+# conv_3_2_expand
+[convolutional]
+filters=144
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_3_2_dwise
+[convolutional]
+groups=144
+filters=144
+size=5
+pad=1
+stride=2
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=8
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=144
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_3_2_linear
+[convolutional]
+filters=40
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV4 - MBConv6 - 2 (2)
+# conv_4_1_expand
+[convolutional]
+filters=192
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_4_1_dwise
+[convolutional]
+groups=192
+filters=192
+size=5
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=16
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=192
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_4_1_linear
+[convolutional]
+filters=40
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+
+
+### CONV5 - MBConv6 - 1 (3)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_4_2
+[shortcut]
+from=-8
+activation=linear
+
+# conv_4_3_expand
+[convolutional]
+filters=192
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_4_3_dwise
+[convolutional]
+groups=192
+filters=192
+size=3
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=16
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=192
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_4_3_linear
+[convolutional]
+filters=80
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV5 - MBConv6 - 2 (3)
+# conv_4_4_expand
+[convolutional]
+filters=384
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_4_4_dwise
+[convolutional]
+groups=384
+filters=384
+size=3
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=24
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=384
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_4_4_linear
+[convolutional]
+filters=80
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV5 - MBConv6 - 3 (3)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_4_4
+[shortcut]
+from=-8
+activation=linear
+
+# conv_4_5_expand
+[convolutional]
+filters=384
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_4_5_dwise
+[convolutional]
+groups=384
+filters=384
+size=3
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=24
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=384
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_4_5_linear
+[convolutional]
+filters=80
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+
+### CONV6 - MBConv6 - 1 (3)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_4_6
+[shortcut]
+from=-8
+activation=linear
+
+# conv_4_7_expand
+[convolutional]
+filters=384
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_4_7_dwise
+[convolutional]
+groups=384
+filters=384
+size=5
+pad=1
+stride=2
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=24
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=384
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_4_7_linear
+[convolutional]
+filters=112
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV6 - MBConv6 - 2 (3)
+# conv_5_1_expand
+[convolutional]
+filters=576
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_5_1_dwise
+[convolutional]
+groups=576
+filters=576
+size=5
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=32
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=576
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_5_1_linear
+[convolutional]
+filters=112
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV6 - MBConv6 - 3 (3)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_5_1
+[shortcut]
+from=-8
+activation=linear
+
+# conv_5_2_expand
+[convolutional]
+filters=576
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_5_2_dwise
+[convolutional]
+groups=576
+filters=576
+size=5
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=32
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=576
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_5_2_linear
+[convolutional]
+filters=112
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV7 - MBConv6 - 1 (4)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_5_2
+[shortcut]
+from=-8
+activation=linear
+
+# conv_5_3_expand
+[convolutional]
+filters=576
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_5_3_dwise
+[convolutional]
+groups=576
+filters=576
+size=5
+pad=1
+stride=2
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=32
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=576
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_5_3_linear
+[convolutional]
+filters=192
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV7 - MBConv6 - 2 (4)
+# conv_6_1_expand
+[convolutional]
+filters=960
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_6_1_dwise
+[convolutional]
+groups=960
+filters=960
+size=5
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=64
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=960
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_6_1_linear
+[convolutional]
+filters=192
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV7 - MBConv6 - 3 (4)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_6_1
+[shortcut]
+from=-8
+activation=linear
+
+# conv_6_2_expand
+[convolutional]
+filters=960
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_6_2_dwise
+[convolutional]
+groups=960
+filters=960
+size=5
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=64
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=960
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_6_2_linear
+[convolutional]
+filters=192
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV7 - MBConv6 - 4 (4)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_6_1
+[shortcut]
+from=-8
+activation=linear
+
+# conv_6_2_expand
+[convolutional]
+filters=960
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_6_2_dwise
+[convolutional]
+groups=960
+filters=960
+size=5
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=64
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=960
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_6_2_linear
+[convolutional]
+filters=192
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+
+### CONV8 - MBConv6 - 1 (1)
+# dropout only before residual connection
+#[dropout]
+#probability=.0
+
+# block_6_2
+[shortcut]
+from=-8
+activation=linear
+
+# conv_6_3_expand
+[convolutional]
+filters=960
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+# conv_6_3_dwise
+[convolutional]
+groups=960
+filters=960
+size=3
+stride=1
+pad=1
+batch_normalize=1
+activation=swish
+
+
+#squeeze-n-excitation
+[avgpool]
+
+# squeeze ratio r=16 (recommended r=16)
+[convolutional]
+filters=64
+size=1
+stride=1
+activation=swish
+
+# excitation
+[convolutional]
+filters=960
+size=1
+stride=1
+activation=logistic
+
+# multiply channels
+[scale_channels]
+from=-4
+
+
+# conv_6_3_linear
+[convolutional]
+filters=320
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=linear
+
+
+### CONV9 - Conv2d 1x1
+# conv_6_4
+[convolutional]
+filters=1280
+size=1
+stride=1
+pad=0
+batch_normalize=1
+activation=swish
+
+##########################
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+activation=leaky
+from=-2
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+
+
+[yolo]
+mask = 3,4,5
+anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
+classes=80
+num=6
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=0
+
+[route]
+layers = -4
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[upsample]
+stride=2
+
+[shortcut]
+activation=leaky
+from=84
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+activation=leaky
+from=-3
+
+[shortcut]
+activation=leaky
+from=84
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+[yolo]
+mask = 1,2,3
+anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
+classes=80
+num=6
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=0
+
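Reviewer note on the cfg above: every squeeze-and-excitation block follows the same four-layer pattern — [avgpool] (now global, thanks to the parser change), a 1x1 squeeze conv, a 1x1 excitation conv with activation=logistic, then [scale_channels] with from=-4, which multiplies the excitation output back onto the depthwise conv four layers up. A shape walkthrough for the first block (C=32, squeeze filters=8), with H and W assumed to be 208 (the 416x416 input after the stride-2 stem):

#include <cstdio>

// Shape walkthrough of the first SE block in the cfg (C=32, r=4).
int main() {
    int N = 1, C = 32, H = 208, W = 208;  // H,W assumed: 416 input / stride-2 stem
    int r = 4;
    std::printf("dwise out:        %dx%dx%dx%d\n", N, C, H, W);
    std::printf("avgpool (global): %dx%dx1x1\n", N, C);
    std::printf("squeeze 1x1:      %dx%dx1x1\n", N, C / r);        // filters=8
    std::printf("excite 1x1:       %dx%dx1x1 (logistic)\n", N, C); // filters=32
    std::printf("scale_channels:   %dx%dx%dx%d\n", N, C, H, W);    // broadcast multiply
    return 0;
}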
"DarknetParser.h" + +int main() { + std::string bin_path = "enet_coco_wo_dropout"; + std::vector input_bins = { + bin_path + "/layers/input.bin" + }; + std::vector output_bins = { + bin_path + "/debug/layer127_out.bin", + bin_path + "/debug/layer136_out.bin" + }; + std::string wgs_path = bin_path + "/layers"; + std::string cfg_path = std::string(TKDNN_PATH) + "/tests/darknet/cfg/enet-coco-wo-dropout.cfg"; + std::string name_path = std::string(TKDNN_PATH) + "/tests/darknet/names/coco.names"; + // downloadWeightsifDoNotExist(input_bins[0], bin_path, "https://cloud.hipert.unimore.it/s/d97CFzYqCPCp5Hg/download"); + + // parse darknet network + tk::dnn::Network *net = tk::dnn::darknetParser(cfg_path, wgs_path, name_path); + net->print(); + + //convert network to tensorRT + tk::dnn::NetworkRT *netRT = new tk::dnn::NetworkRT(net, net->getNetworkRTName(bin_path.c_str())); + + int ret = testInference(input_bins, output_bins, net, netRT); + net->releaseLayers(); + delete net; + delete netRT; + return ret; +}