Skip to content

Commit

Permalink
Added options to specify dimension ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
nightduck committed Jan 10, 2021
1 parent f6884f0 commit 454e72a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
34 changes: 31 additions & 3 deletions include/tkDNN/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

namespace tk { namespace dnn {

// Ordering/layout of the dimensions reported by a TensorRT engine binding.
enum dimFormat_t {
    CHW  = 0,  // implicit batch: the engine omits the batch dimension (assumed 1); tkDNN default
    NCHW = 1,  // explicit batch first, then channels; useful for engines generated from Keras
    NHWC = 2   // explicit batch with channels last; the default layout for Keras models
};

/**
Data representation between layers
n = batch size
Expand All @@ -21,9 +27,31 @@ struct dataDim_t {

dataDim_t() : n(1), c(1), h(1), w(1), l(1) {};

dataDim_t(nvinfer1::Dims &d) :
n(1), c(d.d[0] ? d.d[0] : 1), h(d.d[1] ? d.d[1] : 1),
w(d.d[2] ? d.d[2] : 1), l(d.d[3] ? d.d[3] : 1) {};
// Build a dataDim_t from a TensorRT binding dimension descriptor.
// @param d   dimensions reported by the engine (d.nbDims valid entries in d.d)
// @param df  how the entries of d are ordered (see dimFormat_t)
// Entries beyond d.nbDims, or stored as 0/negative, fall back to 1 —
// TensorRT leaves unused slots at 0, and without this guard a missing
// trailing dimension would zero out tot() and every derived buffer size.
dataDim_t(nvinfer1::Dims &d, dimFormat_t df) :
n(1), c(1), h(1), w(1), l(1) {
    // Safe accessor: dimension i of d, or 1 when absent/non-positive.
    auto dim = [&d](int i) { return (i < d.nbDims && d.d[i] > 0) ? d.d[i] : 1; };
    switch(df) {
        case CHW:   // implicit batch (n stays 1): [c, h, w, (l)]
            c = dim(0);
            h = dim(1);
            w = dim(2);
            l = dim(3);
            break;
        case NCHW:  // explicit batch: [n, c, h, w, (l)]
            n = dim(0);
            c = dim(1);
            h = dim(2);
            w = dim(3);
            l = dim(4);
            break;
        case NHWC:  // explicit batch, channels last: [n, h, w, c, (l)]
            n = dim(0);
            h = dim(1);
            w = dim(2);
            c = dim(3);
            l = dim(4);
            break;
    }
};

dataDim_t(int _n, int _c, int _h, int _w, int _l = 1) :
n(_n), c(_c), h(_h), w(_w), l(_l) {};
Expand Down
2 changes: 1 addition & 1 deletion include/tkDNN/NetworkRT.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class NetworkRT {

PluginFactory *pluginFactory;

NetworkRT(Network *net, const char *name, const char *input_name="data", const char *output_name="out");
NetworkRT(Network *net, const char *name, dimFormat_t dim_format=CHW, const char *input_name="data", const char *output_name="out");
virtual ~NetworkRT();

int getMaxBatchSize() {
Expand Down
8 changes: 4 additions & 4 deletions src/NetworkRT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace tk { namespace dnn {

std::map<Layer*, nvinfer1::ITensor*>tensors;

NetworkRT::NetworkRT(Network *net, const char *name, const char *input_name, const char *output_name) {
NetworkRT::NetworkRT(Network *net, const char *name, dimFormat_t dim_format, const char *input_name, const char *output_name) {

float rt_ver = float(NV_TENSORRT_MAJOR) +
float(NV_TENSORRT_MINOR)/10 +
Expand Down Expand Up @@ -167,17 +167,17 @@ NetworkRT::NetworkRT(Network *net, const char *name, const char *input_name, con


Dims iDim = engineRT->getBindingDimensions(buf_input_idx);
input_dim = dataDim_t(iDim);
input_dim = dataDim_t(iDim, dim_format);
input_dim.print();

Dims oDim = engineRT->getBindingDimensions(buf_output_idx);
output_dim = dataDim_t(oDim);
output_dim = dataDim_t(oDim, dim_format);
output_dim.print();

// create GPU buffers and a stream
for(int i=0; i<engineRT->getNbBindings(); i++) {
Dims dim = engineRT->getBindingDimensions(i);
buffersDIM[i] = dataDim_t(dim);
buffersDIM[i] = dataDim_t(dim, dim_format);
std::cout<<"RtBuffer "<<i<<" dim: "; buffersDIM[i].print();
checkCuda(cudaMalloc(&buffersRT[i], engineRT->getMaxBatchSize()*buffersDIM[i].tot()*sizeof(dnnType)));
}
Expand Down

0 comments on commit 454e72a

Please sign in to comment.