forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
onnx2trt_utils.cpp
2216 lines (2007 loc) · 83.7 KB
/
onnx2trt_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx2trt_utils.hpp"
#include "OnnxAttrs.hpp"
#include "NvInferSafeRuntime.h"
#include <cstring>
#include <set>
namespace onnx2trt
{
// Deleter functor for IPluginV2 objects: releases the plugin through its own destroy() method
// instead of `delete` (presumably used as the deleter of a smart pointer; see the header).
// NOTE(review): `t` is dereferenced unconditionally, so callers must never pass nullptr.
void PluginDeleter::operator()(nvinfer1::IPluginV2* t)
{
t->destroy();
}
// Rejects `input` if its type name matches any entry of `invalidTypes`.
//
// @param input        Tensor or weights whose type name (TensorOrWeights::getType()) is checked.
// @param invalidTypes Type names (e.g. "INT32", "BOOL") that the calling importer cannot handle.
// @return Status::success() if the type is allowed, otherwise a kUNSUPPORTED_NODE error.
Status notInvalidType(TensorOrWeights const& input, std::vector<std::string> const& invalidTypes)
{
    // Query the type name once and compare against each candidate by const reference. The previous
    // lambda took every candidate std::string by value and re-queried getType() per comparison.
    auto const inputType = input.getType();
    bool const invalid = std::any_of(invalidTypes.begin(), invalidTypes.end(),
        [&inputType](std::string const& invalidType) { return inputType == invalidType; });
    if (invalid)
    {
        return MAKE_ERROR("Found invalid input type of " + inputType, ErrorCode::kUNSUPPORTED_NODE);
    }
    return Status::success();
}
// Adds an IActivationLayer of kind `op` on the first input and returns its single output.
//
// @param alpha Optional activation parameter; applied via setAlpha() only when non-null.
// @param beta  Optional activation parameter; applied via setBeta() only when non-null.
// INT32, BOOL and UINT8 inputs are rejected before any layer is added.
NodeImportResult activationHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
    std::vector<TensorOrWeights>& inputs, nvinfer1::ActivationType op, float* alpha, float* beta)
{
    CHECK(notInvalidType(inputs.at(0), {"INT32", "BOOL", "UINT8"}));

    auto& activationInput = convertToTensor(inputs.at(0), ctx);
    auto* activation = ctx->network()->addActivation(activationInput, op);

    // Leave the layer's own defaults in place unless the caller provided explicit values.
    if (alpha != nullptr)
    {
        activation->setAlpha(*alpha);
    }
    if (beta != nullptr)
    {
        activation->setBeta(*beta);
    }

    ctx->registerLayer(activation, node);
    return {{activation->getOutput(0)}};
}
// Clamps `input` to the symmetric range [-clip, clip] using a kCLIP activation layer.
// A negative `clip` disables clipping, and `input` is returned unchanged (no layer is added).
nvinfer1::ITensor* addClip(IImporterContext* ctx, nvinfer1::ITensor* input, float clip)
{
    // Written as !(clip >= 0) rather than clip < 0 so that a NaN threshold also skips clipping.
    if (!(clip >= 0.f))
    {
        return input;
    }
    auto* clipLayer = ctx->network()->addActivation(*input, nvinfer1::ActivationType::kCLIP);
    clipLayer->setAlpha(-clip);
    clipLayer->setBeta(clip);
    return clipLayer->getOutput(0);
}
// Imports ONNX ArgMax / ArgMin as a TopK layer with k = 1, returning only the indices output.
//
// Attributes read: keepdims (default 1), axis (default 0), select_last_index (default 0; only
// accepted for opset >= 12). INT32 inputs are cast to FLOAT before TopK, with a precision warning.
// When select_last_index is set, the data is reversed along `axis` with a Slice, and the TopK
// indices are mapped back with: index = shape[axis] - index - 1.
// If keepdims is 0, the reduced axis is squeezed out of the result.
NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
    std::vector<TensorOrWeights>& inputs, nvinfer1::TopKOperation op)
{
    CHECK(notInvalidType(inputs.at(0), {"UINT8"}));
    nvinfer1::ITensor* tensor = &convertToTensor(inputs.at(0), ctx);

    bool needCast = tensor->getType() == nvinfer1::DataType::kINT32;
    if (needCast)
    {
        LOG_WARNING(
            "TensorRT is using FLOAT32 precision to run an INT32 ArgMax / ArgMin. Rounding errors may occur for large "
            "integer values");
        tensor = castHelper(ctx, tensor, nvinfer1::DataType::kFLOAT);
    }

    // Get attributes.
    OnnxAttrs attrs(node, ctx);
    int32_t keepdims = attrs.get("keepdims", 1);
    int32_t axis = attrs.get("axis", 0);
    int32_t selectLastIndex = attrs.get<int32_t>("select_last_index", 0);
    ASSERT((!selectLastIndex || ctx->getOpsetVersion() >= 12)
            && "Pre-opset 12 ONNX does not support the select_last_index attribute.",
        ErrorCode::kUNSUPPORTED_NODE);

    // Insert a TopK layer with k set to 1.
    int32_t nbDims = tensor->getDimensions().nbDims;
    CHECK(convertAxis(axis, nbDims));
    // convertAxis() also accepts axis == nbDims (needed by Quantize/DequantizeLinear). That is not a
    // valid reduction axis here and would produce an out-of-range TopK axis mask.
    ASSERT((axis < nbDims) && "ArgMax/ArgMin axis must be in the range [-rank, rank - 1].",
        ErrorCode::kINVALID_NODE);
    uint32_t axisMask = 1 << axis;
    nvinfer1::ITopKLayer* layer{nullptr};

    // New attribute added to Opset-12
    // Whether to select the last index or the first index if the {name} appears in multiple indices, default is False
    // (first index).
    if (selectLastIndex)
    {
        // Need to flip the data input along the given axis using the Slice operator
        auto const dims = shapeOf(*tensor);
        ShapeTensor starts = shapeVector(-1);
        ShapeTensor ends = shapeVector(static_cast<int64_t>(INT_MIN));
        ShapeTensor axes = shapeVector(axis);
        ShapeTensor steps = shapeVector(-1);

        if (axes.size() < dims.size())
        {
            // axes specify a subset of the dimensions, or out of order.
            // Convert starts/ends/steps to complete in-order form.
            ShapeTensor const subscripts{axesToInterlaceSubscripts(axes, dims.size())};
            starts = interlace(ctx, similar(ctx, dims, 0), starts, subscripts);
            ends = interlace(ctx, dims, ends, subscripts);
            steps = interlace(ctx, similar(ctx, dims, 1), steps, subscripts);
        }
        decodeOnnxStartsAndEnds(ctx, dims, steps, starts, ends);
        // TensorRT uses sizes of the output dimensions instead of ends.
        ShapeTensor const sizes = computeSliceSizes(ctx, starts, ends, steps, dims);

        nvinfer1::ISliceLayer* slice = addSlice(ctx, *tensor, starts, sizes, steps);
        nvinfer1::ITensor& flippedTensor = *slice->getOutput(0);
        layer = ctx->network()->addTopK(flippedTensor, op, 1, axisMask);
    }
    else
    {
        layer = ctx->network()->addTopK(*tensor, op, 1, axisMask);
    }

    // Validate the layer before registering it. Previously registerLayer() could receive a null pointer.
    ASSERT(layer && "Failed to register layer.", ErrorCode::kUNSUPPORTED_NODE);
    ctx->registerLayer(layer, node);

    // We don't care about the TopK values, just the indices.
    nvinfer1::ITensor* indices = layer->getOutput(1);
    indices->setType(nvinfer1::DataType::kINT32);

    // If selectLastIndex is true, the TopK operation was performed on reversed data on the provided axis.
    // Convert reversed indices back to forward indices by calculating the following:
    // indices = shape(tensor)[axis] - indices - 1
    if (selectLastIndex)
    {
        // Use shapeTensor semantics to support dynamic shapes
        auto const dims = shapeOf(*tensor);
        auto const indicesDims = shapeOf(*indices);
        auto const axisTensor = shapeVector(axis);
        auto const dimOnAxis = gather(ctx, dims, axisTensor);

        // Create constant of shape indicesDims with values tensor.shape[axis]
        auto const tensorDimOnAxis = constantOfShape(ctx, node, &dimOnAxis.tensor(ctx), &indicesDims.tensor(ctx));

        // Create constant of shape indicesDims with values of 1
        auto const ones = constantOfShape(ctx, node, &shapeVector(1).tensor(ctx), &indicesDims.tensor(ctx));

        std::vector<TensorOrWeights> newInputs{tensorDimOnAxis, indices, ones};
        indices = &elementwiseHelper(ctx, node, newInputs, nvinfer1::ElementWiseOperation::kSUB).value().at(0).tensor();
    }

    if (keepdims)
    {
        // The default behavior of the TopK layer is to keepdims.
        return {{indices}};
    }

    // Otherwise, we need to squeeze the axis dimension
    std::vector<int32_t> axes{axis};
    indices = squeezeTensor(ctx, node, *indices, axes);
    return {{indices}};
}
// Raises the rank of `t` to `nbDims` by prepending unit dimensions (numpy-style broadcasting).
// For example, rank 2 -> rank 4 turns [d0, d1] into [1, 1, d0, d1].
// On success `t` is updated in place to point at the reshaped tensor; it is untouched when no
// reshape is needed. Fails for opset < 7 or when `t` already has a rank above `nbDims`.
Status broadcastTensor(IImporterContext* ctx, nvinfer1::ITensor*& t, const int nbDims)
{
    ASSERT(ctx->getOpsetVersion() >= 7 && "Pre-opset 7 broadcasting is unsupported in this version of the ONNX parser",
        ErrorCode::kUNSUPPORTED_NODE);

    auto const originalShape = shapeOf(*t);
    int const originalRank = originalShape.size();
    ASSERT((originalRank <= nbDims) && "Cannot broadcast a higher rank tensor to a lower rank tensor.",
        ErrorCode::kUNSUPPORTED_NODE);

    int const missingDims = nbDims - originalRank;
    if (missingDims == 0)
    {
        return Status::success();
    }

    auto const leadingOnes = fillShapeVector(ctx, 1, shapeVector(missingDims));
    nvinfer1::IShuffleLayer* unsqueeze = addShuffle(ctx, *t, concat(ctx, leadingOnes, shapeOf(*t)));
    t = unsqueeze->getOutput(0);
    return Status::success();
}
// Makes `t1` and `t2` the same rank by prepending unit dimensions to whichever has fewer
// dimensions (see broadcastTensor). Tensors of equal rank are left untouched.
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2)
{
    int const rank1 = t1->getDimensions().nbDims;
    int const rank2 = t2->getDimensions().nbDims;
    if (rank1 == rank2)
    {
        return Status::success();
    }
    // Only the lower-rank tensor is reshaped.
    return (rank1 > rank2) ? broadcastTensor(ctx, t2, rank1) : broadcastTensor(ctx, t1, rank2);
}
// Three-tensor overload: brings t1, t2 and t3 to the largest rank among them, reshaping each in
// place in that order. Returns the first error encountered.
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2, nvinfer1::ITensor*& t3)
{
    int const targetRank = std::max({t1->getDimensions().nbDims, t2->getDimensions().nbDims, t3->getDimensions().nbDims});
    for (nvinfer1::ITensor** tensor : {&t1, &t2, &t3})
    {
        CHECK(broadcastTensor(ctx, *tensor, targetRank));
    }
    return Status::success();
}
// Checks whether two shapes of equal rank are broadcast-compatible.
// Each pair of extents must be equal, or one of them must be 1. A mismatch involving a dynamic
// extent (-1) is only warned about, since it can't be decided until build time.
//
// @return success, or a kUNSUPPORTED_NODE error when the ranks differ or two static extents conflict.
Status isBroadcastValid(IImporterContext* ctx, const nvinfer1::Dims& firstShape, const nvinfer1::Dims& secondShape)
{
    const auto firstRank = firstShape.nbDims;
    const auto secondRank = secondShape.nbDims;
    if (firstRank != secondRank)
    {
        return MAKE_ERROR("Cannot broadcast shapes that have different ranks!", ErrorCode::kUNSUPPORTED_NODE);
    }
    for (int32_t i = 0; i < firstRank; i++)
    {
        const auto firstDim = firstShape.d[i];
        const auto secondDim = secondShape.d[i];
        if (firstDim != secondDim && firstDim != 1 && secondDim != 1)
        {
            if (firstDim == -1 || secondDim == -1)
            {
                LOG_WARNING(
                    "Found dynamic dimensions when checking for broadcast compatibility! TensorRT may fail at "
                    "build-time if the final shapes do not conform to broadcasting rules.");
            }
            else
            {
                // The error must be returned. Previously it was constructed and discarded, so
                // incompatible static shapes were reported as valid.
                return MAKE_ERROR("Found incompatible shapes for tensors that need to be broadcastable!",
                    ErrorCode::kUNSUPPORTED_NODE);
            }
        }
    }
    return Status::success();
}
// Helper functions for calculateBias:
// Returns the linear offset sum(dimension_count[i] * pitches[i]) over every dimension i except `axis`.
// `pitches` must have at least as many entries as `dimension_count`.
int32_t getBias(const std::vector<int32_t>& dimension_count, const std::vector<int32_t>& pitches, int32_t axis)
{
    int32_t bias{0};
    int32_t const rank = static_cast<int32_t>(dimension_count.size());
    for (int32_t dim = 0; dim < rank; ++dim)
    {
        if (dim == axis)
        {
            continue; // the gathered axis does not contribute to the bias
        }
        bias += dimension_count[dim] * pitches[dim];
    }
    return bias;
}
// Advances the multi-dimensional counter `dimensionCount` by one position within the extents of `idxDims`.
// The last (innermost) entry is incremented first. Any entry that reaches its extent wraps to 0 and
// carries into the entry before it, like an odometer. After the final position every entry wraps back to 0.
void incrementOuterDimension(std::vector<int32_t>& dimensionCount, nvinfer1::Dims idxDims)
{
    int32_t axis = static_cast<int32_t>(dimensionCount.size()) - 1;
    while (axis >= 0)
    {
        int const limit = idxDims.d[axis];
        ++dimensionCount[axis];
        if (dimensionCount[axis] != limit)
        {
            return; // no carry needed
        }
        dimensionCount[axis] = 0;
        --axis;
    }
}
// For each of the volume(idxDims) index positions (visited in row-major order), computes the
// getBias() offset of that position in the data tensor, skipping `axis`.
// `daDims` only determines the rank of the position counter.
std::vector<int32_t> calculateBias(
    const nvinfer1::Dims& daDims, const nvinfer1::Dims& idxDims, const std::vector<int32_t>& pitches, int32_t axis)
{
    std::vector<int32_t> counter(daDims.nbDims, 0);
    std::vector<int32_t> biases;
    int64_t const numPositions = volume(idxDims);
    for (int64_t position = 0; position < numPositions; ++position)
    {
        biases.push_back(getBias(counter, pitches, axis));
        incrementOuterDimension(counter, idxDims);
    }
    return biases;
}
// Computes row-major element strides ("pitches") for `inputDims`: pitches[last] = 1 and
// pitches[i] = pitches[i + 1] * inputDims.d[i + 1].
// A shape with no dimensions yields an empty vector.
std::vector<int32_t> calculatePitches(const nvinfer1::Dims& inputDims)
{
    int32_t const nbDims = inputDims.nbDims;
    // Rank-0 (or malformed negative-rank) shapes have no strides. Previously a rank-0 shape wrote to
    // pitches[-1], an out-of-bounds access.
    if (nbDims <= 0)
    {
        return {};
    }
    std::vector<int32_t> pitches(nbDims);
    int32_t pitch = 1;
    pitches[nbDims - 1] = pitch;
    for (int32_t i = nbDims - 2; i >= 0; i--)
    {
        pitch *= inputDims.d[i + 1];
        pitches[i] = pitch;
    }
    return pitches;
}
// Returns true when every scale factor outside the trailing `n` entries is exactly 1, i.e. the
// resize only touches the innermost `n` dimensions (linear resize supports up to n = 3).
// When there are at most `n` scale factors, the answer is always true.
bool canUseNDResize(size_t const scaleSize, float const* scaleFactors, size_t const n)
{
    if (scaleSize <= n)
    {
        return true;
    }
    size_t const leadingCount = scaleSize - n;
    return std::all_of(scaleFactors, scaleFactors + leadingCount, [](float factor) { return factor == 1; });
}
// Converts `input` to `dtype` by passing it through an IIdentityLayer whose output type is overridden.
// Returns the identity layer's output. The layer is not registered with the importer context.
nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype)
{
    auto* identityLayer = ctx->network()->addIdentity(*input);
    identityLayer->setOutputType(0, dtype);
    return identityLayer->getOutput(0);
}
// Broadcasts `constant` to the runtime shape held in the shape tensor `shape`.
// `constant` is first reshaped to an all-ones shape of the same rank as `shape`. It is then sliced
// with start 0, size `shape` and stride 0, so every output element re-reads the single input value.
// NOTE(review): `node` is not used by this function.
nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
    nvinfer1::ITensor* constant, nvinfer1::ITensor* shape)
{
    ShapeTensor outputShape{*shape};
    // Used both as slice start and as slice stride.
    ShapeTensor startAndStride = similar(ctx, outputShape, 0);
    ShapeTensor unitShape = similar(ctx, outputShape, 1);
    nvinfer1::ITensor* rankMatched = &reshape(ctx, *constant, unitShape);
    auto broadcastSlice = addSlice(ctx, *rankMatched, startAndStride, outputShape, startAndStride);
    return broadcastSlice->getOutput(0);
}
// Normalizes `axis` in place: negative values count from the back (axis + nbDims).
// The result must lie in [0, nbDims]. The upper bound nbDims itself is accepted on purpose, for
// QuantDequantLinearHelper.
Status convertAxis(int& axis, int nbDims)
{
    axis = (axis < 0) ? axis + nbDims : axis;
    ASSERT((axis >= 0 && axis <= nbDims) && "Axis must be in the range [0, nbDims].", ErrorCode::kUNSUPPORTED_NODE);
    return Status::success();
}
// Maps an ONNX TensorProto data type onto the TensorRT type used to represent it.
// DOUBLE is narrowed to kFLOAT and INT64 to kINT32 (value clamping happens in convertINT64).
// Returns false and prints a message to std::cerr for unsupported types; `*trt_dtype` is then left unchanged.
bool convertDtype(int32_t onnx_dtype, nvinfer1::DataType* trt_dtype)
{
    switch (onnx_dtype)
    {
    case ::ONNX_NAMESPACE::TensorProto::DOUBLE:
    case ::ONNX_NAMESPACE::TensorProto::FLOAT: *trt_dtype = nvinfer1::DataType::kFLOAT; return true;
    case ::ONNX_NAMESPACE::TensorProto::INT8: *trt_dtype = nvinfer1::DataType::kINT8; return true;
    case ::ONNX_NAMESPACE::TensorProto::UINT8: *trt_dtype = nvinfer1::DataType::kUINT8; return true;
    case ::ONNX_NAMESPACE::TensorProto::FLOAT16: *trt_dtype = nvinfer1::DataType::kHALF; return true;
    case ::ONNX_NAMESPACE::TensorProto::BOOL: *trt_dtype = nvinfer1::DataType::kBOOL; return true;
    case ::ONNX_NAMESPACE::TensorProto::INT32:
    case ::ONNX_NAMESPACE::TensorProto::INT64: *trt_dtype = nvinfer1::DataType::kINT32; return true;
    default: break;
    }
    std::cerr << "Unsupported ONNX data type: " << getDtypeName(onnx_dtype) << " (" << std::to_string(onnx_dtype)
              << ")" << std::endl;
    return false;
}
// Narrows volume(shape) INT64 weights to INT32, stored in a temporary weights buffer owned by `ctx`.
// Values outside the INT32 range are clamped to INT32_MIN/INT32_MAX. Each clamp is logged at
// verbose level. The generic downcast warning and the "values were clamped" warning are each
// emitted at most once per importer context.
int32_t* convertINT64(const int64_t* weightValues, nvinfer1::Dims shape, IImporterContext* ctx)
{
    auto ctxImpl = static_cast<ImporterContext*>(ctx);
    if (!ctxImpl->isConvertINT64Logged())
    {
        LOG_WARNING(
            "Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. "
            "Attempting to cast down to INT32.");
        ctxImpl->setConvertINT64Logged(true);
    }

    const size_t count = volume(shape);
    auto* narrowed = reinterpret_cast<int32_t*>(ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, shape).values);
    const int64_t upper = static_cast<int64_t>(INT32_MAX);
    const int64_t lower = static_cast<int64_t>(INT32_MIN);
    bool anyClamped = false;

    for (size_t idx = 0; idx < count; ++idx)
    {
        const int64_t value = weightValues[idx];
        if (value >= lower && value <= upper)
        {
            narrowed[idx] = static_cast<int32_t>(value);
            continue;
        }
        narrowed[idx] = static_cast<int32_t>(value > upper ? upper : lower);
        LOG_VERBOSE("Weight at index " << idx << ": " << weightValues[idx]
                                       << " is out of range. Clamping to: " << narrowed[idx]);
        anyClamped = true;
    }

    if (anyClamped && !ctxImpl->isConvertINT64OutOfBoundsLogged())
    {
        LOG_WARNING("One or more weights outside the range of INT32 was clamped");
        ctxImpl->setConvertINT64OutOfBoundsLogged(true);
    }
    return narrowed;
}
// Converts ONNX-style pads into two INT32 constant tensors of length nbInputDims:
// `startTensor` holds -pre for each dimension and `totalPaddingTensor` holds pre + post.
// `onnxPadding` is read as [pre_0, ..., pre_{k-1}, post_0, ..., post_{k-1}] and applied to the
// trailing k dimensions. Leading dimensions get zero padding.
// Returns false for an odd-length pad list, more padded dims than nbInputDims, a negative pad, or
// a failed constant creation.
bool convertOnnxPadding(IImporterContext* ctx, int32_t nbInputDims, const std::vector<int32_t>& onnxPadding,
    nvinfer1::ITensor*& startTensor, nvinfer1::ITensor*& totalPaddingTensor)
{
    if (onnxPadding.size() % 2U != 0)
    {
        return false;
    }
    const auto nbPaddedDims = static_cast<int32_t>(onnxPadding.size() / 2U);
    const auto firstPaddedDim = nbInputDims - nbPaddedDims;
    if (firstPaddedDim < 0)
    {
        return false;
    }

    std::vector<int32_t> start(nbInputDims, 0);
    std::vector<int32_t> totalPadding(nbInputDims, 0);
    for (int32_t k = 0; k < nbPaddedDims; ++k)
    {
        const auto before = onnxPadding[k];
        const auto after = onnxPadding[nbPaddedDims + k];
        if (before < 0 || after < 0)
        {
            return false; // negative pads (cropping) are not supported here
        }
        start[firstPaddedDim + k] = -before;
        totalPadding[firstPaddedDim + k] = before + after;
    }

    startTensor
        = addConstant(ctx, start, ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1, {nbInputDims}})->getOutput(0);
    totalPaddingTensor
        = addConstant(ctx, totalPadding, ::ONNX_NAMESPACE::TensorProto::INT32, nvinfer1::Dims{1, {nbInputDims}})
              ->getOutput(0);
    return startTensor && totalPaddingTensor;
}
// Returns true when every element of the INT8 shift (zero-point) weights is zero.
// Empty weights are trivially all-zero.
bool shiftIsAllZeros(const ShapedWeights& shiftInt8)
{
    const auto* first = static_cast<const int8_t*>(shiftInt8.values);
    const auto* last = first + shiftInt8.count();
    return std::none_of(first, last, [](int8_t value) { return value != 0; });
}
// Creates zero-filled shift weights of ONNX data type `type`, with the same shape as `shiftInt8`.
// Warns if the provided INT8 shifts contain non-zero values, since only zero shifts are supported
// for QuantizeLinear/DequantizeLinear. The returned weights live in ctx's temporary weight storage.
onnx2trt::ShapedWeights createZeroShifts(const onnx2trt::ShapedWeights& shiftInt8, int32_t type, IImporterContext* ctx)
{
    const auto* v = static_cast<const int8_t*>(shiftInt8.values);
    if (!std::all_of(v, v + shiftInt8.count(), [](int8_t x) { return x == 0; }))
    {
        LOG_WARNING("TensorRT currenly supports only zero shifts values for QuatizeLinear/DequantizeLinear ops");
    }
    auto shift = ctx->createTempWeights(type, shiftInt8.shape);
    // Zero the buffer byte-wise. All-zero bytes encode 0 for every float and integer dtype, so this is
    // correct for any `type`. The previous loop always wrote `float` elements, which overran the buffer
    // for types narrower than 4 bytes (e.g. FLOAT16, INT8).
    if (shift.size_bytes() > 0)
    {
        std::memset(shift.values, 0, shift.size_bytes());
    }
    return shift;
}
// Builds a tensor of zeros with the same shape and type as `data`.
// A rank-0 FLOAT constant 0 is cast to data's type, given data's rank via broadcastTensors(), and
// multiplied element-wise (kPROD) with `data`. The output therefore follows data's (possibly dynamic) shape.
// NOTE(review): because the zeros come from a multiplication, elements of `data` that are Inf or NaN
// produce NaN rather than 0 — confirm callers never rely on such inputs.
// NOTE(review): the Status returned by broadcastTensors() is ignored (e.g. the opset < 7 rejection).
nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* data)
{
nvinfer1::ITensor* zero
= addConstant(ctx, std::vector<float>{0.f}, ::ONNX_NAMESPACE::TensorProto::FLOAT, {0, {1}})->getOutput(0);
zero = castHelper(ctx, zero, data->getType());
broadcastTensors(ctx, zero, data);
zero = ctx->network()->addElementWise(*data, *zero, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
return zero;
}
// Converts volume(shape) INT32 values into `DataType` elements, stored in a temporary weights buffer
// of ONNX type `onnxdtype` owned by `ctx`. Each element is converted with static_cast.
// It is used for FLOAT16 (as uint16_t), INT8 and BOOL (as uint8_t) weights that ONNX stores in int32_data.
template <typename DataType>
DataType* convertINT32Data(const int32_t* weightValues, nvinfer1::Dims shape, int32_t onnxdtype, IImporterContext* ctx)
{
    const size_t count = volume(shape);
    auto* converted = reinterpret_cast<DataType*>(ctx->createTempWeights(onnxdtype, shape).values);
    std::transform(weightValues, weightValues + count, converted,
        [](int32_t value) { return static_cast<DataType>(value); });
    return converted;
}
// Widens volume(shape) UINT8 weights to INT32, stored in a temporary weights buffer owned by `ctx`.
// Every UINT8 value fits in INT32, so no clamping is needed.
int32_t* convertUINT8(const uint8_t* weightValues, nvinfer1::Dims shape, IImporterContext* ctx)
{
    const size_t count = volume(shape);
    auto* widened = reinterpret_cast<int32_t*>(ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, shape).values);
    std::copy(weightValues, weightValues + count, widened);
    return widened;
}
float* convertDouble(const double* weightValues, nvinfer1::Dims shape, IImporterContext* ctx)
{
auto ctxImpl = static_cast<ImporterContext*>(ctx);
if (!ctxImpl->isConvertDoubleLogged())
{
LOG_WARNING(
"Your ONNX model has been generated with double-typed weights, while TensorRT does not natively support "
"double. "
"Attempting to cast down to float.");
ctxImpl->setConvertDoubleLogged(true);
}
const size_t nbWeights = volume(shape);
float* floatWeights{
reinterpret_cast<float*>(ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::FLOAT, shape).values)};
bool outOfBounds{false};
const double floatMax = static_cast<double>(std::numeric_limits<float>::max());
const double floatMin = static_cast<double>(std::numeric_limits<float>::lowest());
for (size_t i = 0; i < nbWeights; i++)
{
if (weightValues[i] > floatMax || weightValues[i] < floatMin)
{
floatWeights[i] = static_cast<float>(std::max(std::min(weightValues[i], floatMax), floatMin));
LOG_WARNING("Weight at index " << i << ": " << weightValues[i]
<< " is out of range. Clamping to: " << floatWeights[i]);
outOfBounds = true;
}
else
{
floatWeights[i] = static_cast<float>(weightValues[i]);
}
}
if (outOfBounds && !ctxImpl->isConvertDoubleOutOfBoundsLogged())
{
LOG_WARNING("One or more weights outside the range of FLOAT was clamped");
ctxImpl->setConvertDoubleOutOfBoundsLogged(true);
}
return floatWeights;
}
bool convertOnnxWeights(
const ::ONNX_NAMESPACE::TensorProto& onnxTensor, onnx2trt::ShapedWeights* weights, IImporterContext* ctx)
{
void* dataPtr{nullptr};
size_t nbytes{0};
auto onnxDtype = onnxTensor.data_type();
nvinfer1::Dims shape{};
shape.nbDims = onnxTensor.dims().size();
std::copy_n(onnxTensor.dims().begin(), shape.nbDims, shape.d);
// ONNX weight values can be stored in either the TensorProto itself, or in an external file in the case
// of large models. Check for this here.
auto dataLocation = onnxTensor.data_location();
// External Data
if (dataLocation == 1)
{
std::string location{""};
int64_t offset{0};
int64_t length{0};
// onnxTensor.external_data() is a String : String map that holds metadata about how to read from an external
// file
for (auto onnxMapEntry : onnxTensor.external_data())
{
auto keyName = onnxMapEntry.key();
if (keyName == "location")
{
location = onnxMapEntry.value();
}
else if (keyName == "offset")
{
offset = std::atoll(onnxMapEntry.value().c_str());
}
else if (keyName == "length")
{
length = std::atoll(onnxMapEntry.value().c_str());
}
// Not used at the moment
else if (keyName == "checksum")
{
continue;
}
else
{
LOG_ERROR("Key value of: " << keyName << " was not expected!");
return false;
}
}
// Buffer to hold the data read from the file
std::vector<char> dataBuf;
// Will update dataBuf and nbytes by reference.
if (!parseExternalWeights(ctx, location, ctx->getOnnxFileLocation(), offset, length, dataBuf, nbytes))
{
return false;
}
// For weights parsed from external files, createTempWeights is necessary to keep them in scope
ShapedWeights externalWeights;
dataPtr = dataBuf.data();
// Cast non-native TRT types to their corresponding proxy types
if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT64)
{
// Cast INT64 weights to INT32.
dataPtr = convertINT64(reinterpret_cast<const int64_t*>(dataPtr), shape, ctx);
nbytes = nbytes / (sizeof(int64_t) / sizeof(int32_t));
onnxDtype = ::ONNX_NAMESPACE::TensorProto::INT32;
}
else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::UINT8)
{
// Cast UINT8 weights to INT32.
dataPtr = convertUINT8(reinterpret_cast<const uint8_t*>(dataPtr), shape, ctx);
nbytes = nbytes * (sizeof(int32_t) / sizeof(uint8_t));
onnxDtype = ::ONNX_NAMESPACE::TensorProto::INT32;
}
else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::DOUBLE)
{
// Cast DOUBLE weights to FLOAT.
dataPtr = convertDouble(reinterpret_cast<const double*>(dataPtr), shape, ctx);
nbytes = nbytes / (sizeof(double) / sizeof(float));
onnxDtype = ::ONNX_NAMESPACE::TensorProto::FLOAT;
}
// Create the holder for external weights.
externalWeights = ctx->createTempWeights(onnxDtype, shape);
// Check if the size of external weights is as expected.
if (externalWeights.size_bytes() != nbytes)
{
LOG_ERROR("Unexpected size for the external weights! Expected size: "
<< externalWeights.size_bytes()
<< " bytes (shape = "
<< shape
<< "). Actual size: "
<< nbytes
<< " bytes.");
return false;
}
// Copy the weight values into externalWeights.
std::memcpy(externalWeights.values, dataPtr, nbytes);
*weights = externalWeights;
return true;
}
// Weights information is within the TensorProto itself
// Cast non-native TRT types to their corresponding proxy types
if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT64)
{
if (onnxTensor.raw_data().size() > 0)
{
dataPtr = convertINT64(reinterpret_cast<const int64_t*>(onnxTensor.raw_data().data()), shape, ctx);
nbytes = onnxTensor.raw_data().size() / (sizeof(int64_t) / sizeof(int32_t));
}
else if (onnxTensor.int64_data().size() > 0)
{
dataPtr = convertINT64(onnxTensor.int64_data().data(), shape, ctx);
nbytes = onnxTensor.int64_data().size() * sizeof(int32_t);
}
onnxDtype = ::ONNX_NAMESPACE::TensorProto::INT32;
}
else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::UINT8)
{
if (onnxTensor.raw_data().size() > 0)
{
dataPtr = convertUINT8(reinterpret_cast<const uint8_t*>(onnxTensor.raw_data().data()), shape, ctx);
nbytes = onnxTensor.raw_data().size() * (sizeof(int32_t) / sizeof(uint8_t));
}
else if (onnxTensor.int32_data().size() > 0)
{
dataPtr = (void*) onnxTensor.int32_data().data();
nbytes = onnxTensor.int32_data().size() * sizeof(int32_t);
}
onnxDtype = ::ONNX_NAMESPACE::TensorProto::INT32;
}
else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::DOUBLE)
{
if (onnxTensor.raw_data().size() > 0)
{
dataPtr = convertDouble(reinterpret_cast<const double*>(onnxTensor.raw_data().data()), shape, ctx);
nbytes = onnxTensor.raw_data().size() / (sizeof(double) / sizeof(float));
}
else if (onnxTensor.double_data().size() > 0)
{
dataPtr = convertDouble(onnxTensor.double_data().data(), shape, ctx);
nbytes = onnxTensor.double_data().size() * sizeof(float);
}
onnxDtype = ::ONNX_NAMESPACE::TensorProto::FLOAT;
}
// Check for supported types that can be found in the int32_data field in the TensorProto
// https://github.com/onnx/onnx/blob/master/onnx/onnx.proto#L528
else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT32 || onnxDtype == ::ONNX_NAMESPACE::TensorProto::FLOAT16
|| onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT8 || onnxDtype == ::ONNX_NAMESPACE::TensorProto::BOOL)
{
if (onnxTensor.raw_data().size() > 0)
{
dataPtr = (void*) (onnxTensor.raw_data().data());
nbytes = onnxTensor.raw_data().size();
}
else
{
switch (onnxDtype)
{
case ::ONNX_NAMESPACE::TensorProto::INT32: dataPtr = (void*) (onnxTensor.int32_data().data()); break;
case ::ONNX_NAMESPACE::TensorProto::FLOAT16:
dataPtr = convertINT32Data<uint16_t>(onnxTensor.int32_data().data(), shape, onnxDtype, ctx);
break;
case ::ONNX_NAMESPACE::TensorProto::INT8:
dataPtr = convertINT32Data<int8_t>(onnxTensor.int32_data().data(), shape, onnxDtype, ctx);
break;
case ::ONNX_NAMESPACE::TensorProto::BOOL:
dataPtr = convertINT32Data<uint8_t>(onnxTensor.int32_data().data(), shape, onnxDtype, ctx);
break;
default:
LOG_ERROR("Found unsupported datatype (" << onnxDtype
<< ") when importing initializer: " << onnxTensor.name());
break;
}
nbytes = onnxTensor.int32_data().size() * getDtypeSize(onnxDtype);
}
}
else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::FLOAT)
{
if (onnxTensor.raw_data().size() > 0)
{
dataPtr = (void*) (onnxTensor.raw_data().data());
nbytes = onnxTensor.raw_data().size();
}
else
{
dataPtr = (void*) (onnxTensor.float_data().data());
nbytes = onnxTensor.float_data().size() * sizeof(float);
}
}
else
{
LOG_ERROR("Found unsupported datatype (" << onnxDtype << ") when importing initializer: " << onnxTensor.name());
return false;
}
onnx2trt::ShapedWeights trt_weights(onnxDtype, dataPtr, shape);
// Sanity check that weights were converted properly
if (trt_weights.size_bytes() != nbytes)
{
LOG_ERROR("Size mismatch when importing initializer: " << onnxTensor.name() << ". Expected size: " << nbytes
<< " , actual size: " << trt_weights.size_bytes());
return false;
}
*weights = trt_weights;
return true;
}
// Reshapes a single-element tensor to rank 0.
// A tensor that is already rank 0 is returned as is. Tensors whose volume is not 1 cannot be
// converted: a verbose message is logged and nullptr is returned.
nvinfer1::ITensor* convertToScalar(IImporterContext* ctx, nvinfer1::ITensor* inpTensor)
{
    nvinfer1::Dims const inputDims = inpTensor->getDimensions();
    if (inputDims.nbDims == 0)
    {
        return inpTensor;
    }

    const auto numElements = volume(inputDims);
    if (numElements != 1)
    {
        LOG_VERBOSE("Cannot convert tensor to scalar. Note: Tensor dimensions were: "
            << inputDims << ", with volume: " << numElements);
        return nullptr;
    }

    // Empty reshape dimensions produce a scalar, so setZeroIsPlaceholder is irrelevant here.
    nvinfer1::IShuffleLayer* toScalar = ctx->network()->addShuffle(*inpTensor);
    toScalar->setReshapeDimensions(nvinfer1::Dims{0});
    return toScalar->getOutput(0);
}
// Returns `input` as an ITensor. Tensors are returned directly.
// Weights are materialized as an IConstantLayer output. A constant layer already recorded under the
// same weights name (ctx->getConstantLayer) is reused. A new named constant is registered with the
// context and has its weights name set on the network.
nvinfer1::ITensor& convertToTensor(TensorOrWeights& input, IImporterContext* ctx)
{
    if (input.is_tensor())
    {
        return input.tensor();
    }

    ShapedWeights& weights = input.weights();
    const char* const weightsName = weights.getName();

    auto const cachedLayer = ctx->getConstantLayer(weightsName);
    if (cachedLayer != nullptr)
    {
        return *(cachedLayer->getOutput(0));
    }

    auto* newConstant = ctx->network()->addConstant(weights.shape, weights);
    if (weightsName)
    {
        // Register layer and constant name into RefitMap.
        ctx->registerLayer(newConstant, weightsName, nullptr);
        ctx->network()->setWeightsName(weights, weightsName);
    }
    return *(newConstant->getOutput(0));
}
// TensorOrWeights overload: tensors are forwarded to convertToScalar(ctx, tensor).
// Single-element weights become a rank-0 constant layer. Weights with any other volume log a verbose
// message and yield nullptr.
nvinfer1::ITensor* convertToScalar(TensorOrWeights& input, IImporterContext* ctx)
{
    if (!input.is_tensor())
    {
        ShapedWeights& weights = input.weights();
        const auto numElements = volume(weights.shape);
        if (numElements != 1)
        {
            LOG_VERBOSE("Cannot convert weights to scalar. Note: Tensor dimensions were: "
                << weights.shape << ", with volume: " << numElements);
            return nullptr;
        }
        return ctx->network()->addConstant(nvinfer1::Dims{0, {0}}, weights)->getOutput(0);
    }
    return convertToScalar(ctx, &input.tensor());
}
// Returns ceil(n / d) for integers, with d != 0.
// The previous form `(n - 1) / d + 1` returned 1 instead of 0 for n == 0 (when d > 1), and gave wrong
// results for some sign combinations. This version corrects C++'s truncating division: it rounds up
// only when there is a remainder and the exact quotient is positive. It also avoids the overflow that
// `(n + d - 1) / d` would risk.
int divCeil(int n, int d)
{
    int const quotient = n / d;
    int const remainder = n % d;
    bool const roundUp = remainder != 0 && ((remainder > 0) == (d > 0));
    return roundUp ? quotient + 1 : quotient;
}
// Validates input types for an element-wise operation:
// - logical ops (AND/OR/XOR) require every input to be BOOL;
// - arithmetic and comparison ops reject BOOL inputs;
// - POW rejects both BOOL and INT32 inputs;
// - EQUAL (and any operation not listed) accepts any type.
bool elementwiseCheck(const std::vector<TensorOrWeights>& inputs, const nvinfer1::ElementWiseOperation op)
{
    auto const isBoolInput = [](const TensorOrWeights& input) { return input.isBool(); };
    auto const isBoolOrInt32Input = [](const TensorOrWeights& input) { return input.isBool() || input.isInt32(); };

    switch (op)
    {
    case nvinfer1::ElementWiseOperation::kAND:
    case nvinfer1::ElementWiseOperation::kOR:
    case nvinfer1::ElementWiseOperation::kXOR: return std::all_of(inputs.begin(), inputs.end(), isBoolInput);

    case nvinfer1::ElementWiseOperation::kDIV:
    case nvinfer1::ElementWiseOperation::kFLOOR_DIV:
    case nvinfer1::ElementWiseOperation::kGREATER:
    case nvinfer1::ElementWiseOperation::kLESS:
    case nvinfer1::ElementWiseOperation::kMAX:
    case nvinfer1::ElementWiseOperation::kMIN:
    case nvinfer1::ElementWiseOperation::kPROD:
    case nvinfer1::ElementWiseOperation::kSUB:
    case nvinfer1::ElementWiseOperation::kSUM: return std::none_of(inputs.begin(), inputs.end(), isBoolInput);

    case nvinfer1::ElementWiseOperation::kPOW: return std::none_of(inputs.begin(), inputs.end(), isBoolOrInt32Input);

    case nvinfer1::ElementWiseOperation::kEQUAL: return true;
    }
    return true;
}
// Folds `inputs` left-to-right with `binary_op`: ((in0 op in1) op in2) ...
// All inputs are first converted to tensors and raised to a common rank by prepending unit dimensions.
// A single input is passed through an identity layer so it is never aliased as a network output.
// Fails if `inputs` is empty, if the operator does not support the input types (see elementwiseCheck),
// or if broadcasting or layer creation fails.
NodeImportResult elementwiseHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node,
    const std::vector<TensorOrWeights>& inputs, nvinfer1::ElementWiseOperation binary_op)
{
    ASSERT((!inputs.empty()) && "Inputs vector is empty.", ErrorCode::kINVALID_NODE);
    // Validate types before touching the network, so a rejected node leaves no dangling constant or
    // shuffle layers behind.
    ASSERT(elementwiseCheck(inputs, binary_op) && "Elementwise layer does not support the given inputs and operator.",
        ErrorCode::kUNSUPPORTED_NODE);

    int maxNbDims = -1;
    for (auto input : inputs)
    {
        maxNbDims = std::max(maxNbDims, input.shape().nbDims);
    }

    std::vector<nvinfer1::ITensor*> inputTensors;
    inputTensors.reserve(inputs.size());
    for (auto input : inputs)
    {
        auto* tensor_ptr = &convertToTensor(input, ctx);

        // Broadcast all input tensors to size of maxNbDims
        broadcastTensor(ctx, tensor_ptr, maxNbDims);
        ASSERT(tensor_ptr->getDimensions().nbDims == maxNbDims && "Failed to broadcast tensors elementwise!",
            ErrorCode::kUNSUPPORTED_NODE);
        inputTensors.push_back(tensor_ptr);
    }

    // Use the first tensor input as the base for the elementwise operation
    nvinfer1::ITensor* combined = inputTensors.at(0);
    if (inputTensors.size() == 1)
    {
        // Note: Single input must be wrapped in identity to avoid messing up network outputs
        return {{identity(ctx, combined)}};
    }
    for (size_t i = 1; i < inputTensors.size(); ++i)
    {
        nvinfer1::ITensor* tensor = inputTensors.at(i);
        ASSERT((tensor->getDimensions().nbDims == combined->getDimensions().nbDims)
                && "The number of dimensions should remain the same adding inputs.",
            ErrorCode::kUNSUPPORTED_NODE);
        auto* layer = ctx->network()->addElementWise(*combined, *tensor, binary_op);
        // Check the layer before registering it. Previously registerLayer() could receive a null pointer.
        ASSERT(layer && "Failed to register layer.", ErrorCode::kUNSUPPORTED_NODE);
        ctx->registerLayer(layer, node);
        combined = layer->getOutput(0);
    }
    return {{combined}};
}
// Reshapes `tensor` to 2-D with a shuffle layer. The new shape is built from the products of the
// dimensions in [0, axis) and in [axis, rank).
// The shuffle layer is registered against `node` only when `regLayer` is true.
// Returns the reshaped tensor.
nvinfer1::ITensor* flattenTensor(
    IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, nvinfer1::ITensor& tensor, int axis, bool regLayer)
{
    auto const inputShape = shapeOf(tensor);
    auto const leading = product(ctx, inputShape, 0, axis, 1);
    auto const trailing = product(ctx, inputShape, axis, inputShape.size(), 1);
    auto const flatShape = concat(ctx, leading, trailing);

    // zeroIsPlaceholder=false: an extent of 0 is taken literally (empty dimension), which keeps empty
    // tensors working.
    nvinfer1::IShuffleLayer* const reshape = addShuffle(ctx, tensor, flatShape, /*zeroIsPlaceholder=*/false);
    if (!regLayer)
    {
        return reshape->getOutput(0);
    }
    ctx->registerLayer(reshape, node);
    return reshape->getOutput(0);
}
// Selects element `dim` of `shapeTensor` along axis 0.
// It adds an INT32 constant holding `dim` with the given `shape`, then a gather layer that uses it as the index.
// Returns the gather output.
nvinfer1::ITensor* gatherDimension(IImporterContext* ctx, nvinfer1::ITensor* shapeTensor, int dim, nvinfer1::Dims shape)
{
    nvinfer1::ITensor* const indexTensor
        = addConstantScalar(ctx, dim, ::ONNX_NAMESPACE::TensorProto_DataType_INT32, shape)->getOutput(0);
    auto* const gatherLayer = ctx->network()->addGather(*shapeTensor, *indexTensor, 0);
    return gatherLayer->getOutput(0);
}
// Computes begin/end padding for ConvTranspose so that the output has `outputShape`.
// Total padding per spatial dim follows
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConvTranspose :
//   total = stride * (in - 1) + output_padding + ((kernel - 1) * dilation + 1) - out
// The input spatial extents are read from inputShape.d[2 + i].
// Results are written into begPadding/endPadding; outputPadding is only read.
void generatePadding(nvinfer1::Dims inputShape, nvinfer1::Dims outputShape, nvinfer1::Dims kernelSize,
    nvinfer1::Dims strides, nvinfer1::Dims dilations, const int nbSpatialDims, nvinfer1::Dims& begPadding,
    nvinfer1::Dims& endPadding, nvinfer1::Dims& outputPadding, nvinfer1::PaddingMode paddingMode)
{
    // The requested output shape may list every dimension (e.g. NCHW) or only the spatial ones (e.g. HW);
    // skip over any leading non-spatial entries.
    auto const firstSpatial = outputShape.nbDims - nbSpatialDims;
    // When the total is odd, the extra unit goes to the beginning for kSAME_UPPER and to the end otherwise.
    // NOTE(review): revisions of the ONNX ConvTranspose pseudo-code have differed on which side receives
    // the extra unit for SAME_UPPER; confirm against the targeted opset/runtime before changing.
    bool const extraAtBegin = (paddingMode == nvinfer1::PaddingMode::kSAME_UPPER);
    for (int32_t dim = 0; dim < nbSpatialDims; ++dim)
    {
        auto const effectiveKernel = (kernelSize.d[dim] - 1) * dilations.d[dim] + 1;
        auto const total = strides.d[dim] * (inputShape.d[2 + dim] - 1) + outputPadding.d[dim] + effectiveKernel
            - outputShape.d[firstSpatial + dim];
        auto const smallerHalf = total / 2;
        auto const largerHalf = total - smallerHalf;
        begPadding.d[dim] = extraAtBegin ? largerHalf : smallerHalf;
        endPadding.d[dim] = extraAtBegin ? smallerHalf : largerHalf;
    }
}
// Returns the default `alpha` for a TensorRT activation type. Presumably these mirror the defaults of the
// corresponding ONNX operators; verify against the importer call sites.
// Throws std::runtime_error for an activation type not listed here.
float getActivationDefaultAlpha(nvinfer1::ActivationType type)
{
    switch (type)
    {
    case nvinfer1::ActivationType::kRELU:
    case nvinfer1::ActivationType::kSIGMOID:
    case nvinfer1::ActivationType::kTANH:
    case nvinfer1::ActivationType::kSOFTSIGN:
    case nvinfer1::ActivationType::kSOFTPLUS:
    case nvinfer1::ActivationType::kCLIP: return 0.f;

    case nvinfer1::ActivationType::kLEAKY_RELU: return 0.01f;
    case nvinfer1::ActivationType::kHARD_SIGMOID: return 0.2f;

    case nvinfer1::ActivationType::kELU:
    case nvinfer1::ActivationType::kSCALED_TANH:
    case nvinfer1::ActivationType::kTHRESHOLDED_RELU: return 1.0f;

    case nvinfer1::ActivationType::kSELU: return 1.67326319217681884765625f;
    }
    throw std::runtime_error{"Unrecognized activation type"};
}
float getActivationDefaultBeta(nvinfer1::ActivationType type)
{
switch (type)
{
case nvinfer1::ActivationType::kRELU: return 0.f;
case nvinfer1::ActivationType::kSIGMOID: return 0.f;
case nvinfer1::ActivationType::kTANH: return 0.f;
case nvinfer1::ActivationType::kLEAKY_RELU: return 0.f;
case nvinfer1::ActivationType::kELU: return 0.f;
case nvinfer1::ActivationType::kSELU: return 1.05070102214813232421875f;
case nvinfer1::ActivationType::kSOFTSIGN: return 0.f;
case nvinfer1::ActivationType::kSOFTPLUS: return 0.f;
case nvinfer1::ActivationType::kCLIP: return 0.f;
case nvinfer1::ActivationType::kHARD_SIGMOID: return 0.5f;