Skip to content

Commit

Permalink
vstrt/vs_tensorrt.cpp: stricter type check
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Sep 7, 2023
1 parent 8544207 commit 471769f
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions vstrt/vs_tensorrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,17 @@ static void VS_CC vsTrtCreate(
auto input_type = d->engines[0]->getBindingDataType(0);
#endif // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85

auto input_sample_type = getSampleType(input_type) == 0 ? stInteger : stFloat;
VSSampleType input_sample_type;
{
auto sample_type = getSampleType(input_type);
if (sample_type == 0) {
input_sample_type = stInteger;
} else if (sample_type == 1) {
input_sample_type = stFloat;
} else {
return set_error("unknown input sample type");
}
}
auto input_bits_per_sample = getBytesPerSample(input_type) * 8;

if (auto err = checkNodes(in_vis, input_sample_type, input_bits_per_sample); err.has_value()) {
Expand All @@ -463,7 +473,17 @@ static void VS_CC vsTrtCreate(
auto output_type = d->engines[0]->getBindingDataType(1);
#endif // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85

auto output_sample_type = getSampleType(output_type) == 0 ? stInteger : stFloat;
VSSampleType output_sample_type;
{
auto sample_type = getSampleType(output_type);
if (sample_type == 0) {
output_sample_type = stInteger;
} else if (sample_type == 1) {
output_sample_type = stFloat;
} else {
return set_error("unknown output sample type");
}
}
auto output_bits_per_sample = getBytesPerSample(output_type) * 8;

setDimensions(
Expand Down

0 comments on commit 471769f

Please sign in to comment.