Commit a8b81f2

cob authored and sileht committed

feat(tensorrt): Add support for onnx image classification models

1 parent b43a6bc · commit a8b81f2

7 files changed: +215 -63 lines changed

src/backends/tensorrt/tensorrtinputconns.cc (+2 -1)

@@ -38,7 +38,8 @@ namespace dd
     for (int c = 0; c < channels; ++c)
       for (int h = 0; h < _height; ++h)
         for (int w = 0; w < _width; ++w)
-          fbuf[offset++] = cvbuf[(converted.cols * h + w) * channels + c];
+          fbuf[offset++]
+              = _scale * cvbuf[(converted.cols * h + w) * channels + c];
   }
 
   void ImgTensorRTInputFileConn::applyMeanToRTBuf(int channels, int i)
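
For context, the loop above repacks OpenCV's interleaved HWC bytes into the planar CHW float layout TensorRT consumes, now applying _scale on the way in. A standalone sketch of the same repack (illustrative names, assuming an 8-bit cv::Mat input):

#include <opencv2/core.hpp>
#include <vector>

// Repack an interleaved (HWC) 8-bit image into a planar (CHW) float
// buffer, scaling each value, as the patched loop does with _scale.
std::vector<float> to_chw_scaled(const cv::Mat &img, float scale)
{
  const int channels = img.channels();
  std::vector<float> fbuf(channels * img.rows * img.cols);
  const unsigned char *cvbuf = img.data;
  size_t offset = 0;
  for (int c = 0; c < channels; ++c) // one output plane per channel
    for (int h = 0; h < img.rows; ++h)
      for (int w = 0; w < img.cols; ++w)
        fbuf[offset++] = scale * cvbuf[(img.cols * h + w) * channels + c];
  return fbuf;
}

A scale of 1/255.f, for instance, maps 8-bit pixels into [0, 1], the input range many ONNX classification models are trained with.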

src/backends/tensorrt/tensorrtlib.cc (+138 -58)

@@ -24,6 +24,7 @@
 #include "tensorrtinputconns.h"
 #include "utils/apitools.h"
 #include "NvInferPlugin.h"
+#include "NvOnnxParser.h"
 #include "protoUtils.h"
 #include <cuda_runtime_api.h>
 #include <string>
@@ -39,7 +40,12 @@ namespace dd
     fileops::list_directory(repo, true, false, false, lfiles);
     for (std::string s : lfiles)
       {
-        if (s.find(engineFileName) != std::string::npos)
+        // Omitting directory name
+        auto fstart = s.find_last_of("/");
+        if (fstart == std::string::npos)
+          fstart = 0;
+
+        if (s.find(engineFileName, fstart) != std::string::npos)
           {
             std::string bs_str;
             for (auto it = s.crbegin(); it != s.crend(); ++it)
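
Previously the engine lookup matched engineFileName anywhere in the listed path, so a directory component containing the engine prefix could produce a false hit; the new code restricts the search to the path's final component. A self-contained sketch of that check:

#include <iostream>
#include <string>

// Match the engine file name only in the path's final component,
// mirroring the find_last_of("/") guard in the hunk above.
bool matches_engine_file(const std::string &path, const std::string &name)
{
  auto fstart = path.find_last_of('/');
  if (fstart == std::string::npos)
    fstart = 0;
  return path.find(name, fstart) != std::string::npos;
}

int main()
{
  // a directory named after the engine no longer matches; a real
  // engine file still does
  std::cout << matches_engine_file("/repo/TRTengine_x/w.caffemodel", "TRTengine")
            << matches_engine_file("/repo/x/TRTengine_bs1", "TRTengine") << '\n'; // 01
}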
@@ -134,6 +140,10 @@ namespace dd
         _max_batch_size = nmbs;
         this->_logger->info("setting max batch size to {}", _max_batch_size);
       }
+    if (ad.has("nclasses"))
+      {
+        _nclasses = ad.get("nclasses").get<int>();
+      }
 
     if (ad.has("dla"))
       _dla = ad.get("dla").get<int>();
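
With this, the class count can be passed at service creation, which matters for ONNX models since they carry no caffe prototxt to infer it from; as a later hunk in this file shows, findNClasses remains the fallback when nclasses is unset. A sketch of that resolution order, with a hypothetical stub standing in for the real prototxt inspection:

#include <string>

// Hypothetical stub; the real findNClasses inspects the caffe prototxt.
static int findNClasses(const std::string & /*def*/, bool /*bbox*/)
{
  return 1000;
}

// Resolution order mirrored from this commit: a user-supplied value
// wins, prototxt inspection is only a fallback.
int resolve_nclasses(int user_nclasses, const std::string &def, bool bbox)
{
  return user_nclasses != 0 ? user_nclasses : findNClasses(def, bbox);
}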
@@ -244,6 +254,114 @@ namespace dd
     return 0;
   }
 
+  template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
+            class TMLModel>
+  nvinfer1::ICudaEngine *
+  TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
+              TMLModel>::read_engine_from_caffe(const std::string &out_blob)
+  {
+    int fixcode = fixProto(this->_mlmodel._repo + "/" + "net_tensorRT.proto",
+                           this->_mlmodel._def);
+    switch (fixcode)
+      {
+      case 1:
+        this->_logger->error("TRT backend could not open model prototxt");
+        break;
+      case 2:
+        this->_logger->error("TRT backend could not write "
+                             "transformed model prototxt");
+        break;
+      default:
+        break;
+      }
+
+    nvinfer1::INetworkDefinition *network = _builder->createNetworkV2(0U);
+    nvcaffeparser1::ICaffeParser *caffeParser
+        = nvcaffeparser1::createCaffeParser();
+
+    const nvcaffeparser1::IBlobNameToTensor *blobNameToTensor
+        = caffeParser->parse(
+            std::string(this->_mlmodel._repo + "/" + "net_tensorRT.proto")
+                .c_str(),
+            this->_mlmodel._weights.c_str(), *network, _datatype);
+    if (!blobNameToTensor)
+      throw MLLibInternalException("Error while parsing caffe model "
+                                   "for conversion to TensorRT");
+
+    network->markOutput(*blobNameToTensor->find(out_blob.c_str()));
+
+    if (out_blob == "detection_out")
+      network->markOutput(*blobNameToTensor->find("keep_count"));
+    _builder->setMaxBatchSize(_max_batch_size);
+    _builderc->setMaxWorkspaceSize(_max_workspace_size);
+
+    network->getLayer(0)->setPrecision(nvinfer1::DataType::kFLOAT);
+
+    nvinfer1::ILayer *outl = NULL;
+    int idx = network->getNbLayers() - 1;
+    while (outl == NULL)
+      {
+        nvinfer1::ILayer *l = network->getLayer(idx);
+        if (strcmp(l->getName(), out_blob.c_str()) == 0)
+          {
+            outl = l;
+            break;
+          }
+        idx--;
+      }
+    // force output to be float32
+    outl->setPrecision(nvinfer1::DataType::kFLOAT);
+    nvinfer1::ICudaEngine *engine
+        = _builder->buildEngineWithConfig(*network, *_builderc);
+
+    network->destroy();
+    if (caffeParser != nullptr)
+      caffeParser->destroy();
+
+    return engine;
+  }
+
+  template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
+            class TMLModel>
+  nvinfer1::ICudaEngine *
+  TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
+              TMLModel>::read_engine_from_onnx()
+  {
+    const auto explicitBatch
+        = 1U << static_cast<uint32_t>(
+              nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+
+    nvinfer1::INetworkDefinition *network
+        = _builder->createNetworkV2(explicitBatch);
+
+    nvonnxparser::IParser *onnxParser
+        = nvonnxparser::createParser(*network, trtLogger);
+    onnxParser->parseFromFile(this->_mlmodel._model.c_str(),
+                              int(nvinfer1::ILogger::Severity::kWARNING));
+
+    if (onnxParser->getNbErrors() != 0)
+      {
+        for (int i = 0; i < onnxParser->getNbErrors(); ++i)
+          {
+            this->_logger->error(onnxParser->getError(i)->desc());
+          }
+        throw MLLibInternalException(
+            "Error while parsing onnx model for conversion to "
+            "TensorRT");
+      }
+    _builder->setMaxBatchSize(_max_batch_size);
+    _builderc->setMaxWorkspaceSize(_max_workspace_size);
+
+    nvinfer1::ICudaEngine *engine
+        = _builder->buildEngineWithConfig(*network, *_builderc);
+
+    network->destroy();
+    if (onnxParser != nullptr)
+      onnxParser->destroy();
+
+    return engine;
+  }
+
   template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
             class TMLModel>
   int TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
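
The new read_engine_from_onnx is essentially the canonical NvOnnxParser flow: explicit-batch network, parse, build. For reference, a condensed self-contained sketch of that flow, assuming TensorRT 7-era APIs (destroy()-based cleanup) and an illustrative caller-supplied gLogger:

#include <NvInfer.h>
#include <NvOnnxParser.h>

// Condensed sketch: ONNX file -> serialized TensorRT engine.
// The ONNX parser only accepts explicit-batch network definitions.
nvinfer1::IHostMemory *onnx_to_engine(const char *onnx_path,
                                      nvinfer1::ILogger &gLogger,
                                      size_t max_workspace)
{
  const auto explicitBatch = 1U << static_cast<uint32_t>(
      nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

  nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(gLogger);
  nvinfer1::INetworkDefinition *network
      = builder->createNetworkV2(explicitBatch);
  nvonnxparser::IParser *parser
      = nvonnxparser::createParser(*network, gLogger);

  nvinfer1::IHostMemory *blob = nullptr;
  if (parser->parseFromFile(onnx_path,
                            int(nvinfer1::ILogger::Severity::kWARNING)))
    {
      nvinfer1::IBuilderConfig *config = builder->createBuilderConfig();
      config->setMaxWorkspaceSize(max_workspace);
      nvinfer1::ICudaEngine *engine
          = builder->buildEngineWithConfig(*network, *config);
      if (engine)
        {
          blob = engine->serialize();
          engine->destroy();
        }
      config->destroy();
    }
  // parse errors are retrievable via parser->getError(i)->desc()
  parser->destroy();
  network->destroy();
  builder->destroy();
  return blob;
}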
@@ -293,7 +411,12 @@ namespace dd
             "timeseries not yet implemented over tensorRT backend");
       }
 
-    _nclasses = findNClasses(this->_mlmodel._def, _bbox);
+    if (_nclasses == 0)
+      {
+        this->_logger->info("try to determine number of classes...");
+        _nclasses = findNClasses(this->_mlmodel._def, _bbox);
+      }
+
     if (_bbox)
       _top_k = findTopK(this->_mlmodel._def);
 
@@ -335,65 +458,25 @@ namespace dd
 
     if (!engineRead)
       {
+        nvinfer1::ICudaEngine *le = nullptr;
 
-        int fixcode
-            = fixProto(this->_mlmodel._repo + "/" + "net_tensorRT.proto",
-                       this->_mlmodel._def);
-        switch (fixcode)
+        if (this->_mlmodel._model.find("net_tensorRT.proto")
+                != std::string::npos
+            || !this->_mlmodel._def.empty())
           {
-          case 1:
-            this->_logger->error(
-                "TRT backend could not open model prototxt");
-            break;
-          case 2:
-            this->_logger->error("TRT backend could not write "
-                                 "transformed model prototxt");
-            break;
-          default:
-            break;
+            le = read_engine_from_caffe(out_blob);
           }
-
-        nvinfer1::INetworkDefinition *network
-            = _builder->createNetworkV2(0U);
-        nvcaffeparser1::ICaffeParser *caffeParser
-            = nvcaffeparser1::createCaffeParser();
-
-        const nvcaffeparser1::IBlobNameToTensor *blobNameToTensor
-            = caffeParser->parse(std::string(this->_mlmodel._repo + "/"
-                                             + "net_tensorRT.proto")
-                                     .c_str(),
-                                 this->_mlmodel._weights.c_str(), *network,
-                                 _datatype);
-        if (!blobNameToTensor)
-          throw MLLibInternalException("Error while parsing caffe model "
-                                       "for conversion to TensorRT");
-
-        network->markOutput(*blobNameToTensor->find(out_blob.c_str()));
-
-        if (out_blob == "detection_out")
-          network->markOutput(*blobNameToTensor->find("keep_count"));
-        _builder->setMaxBatchSize(_max_batch_size);
-        _builderc->setMaxWorkspaceSize(_max_workspace_size);
-
-        network->getLayer(0)->setPrecision(nvinfer1::DataType::kFLOAT);
-
-        nvinfer1::ILayer *outl = NULL;
-        int idx = network->getNbLayers() - 1;
-        while (outl == NULL)
+        else if (this->_mlmodel._model.find("net_tensorRT.onnx")
+                 != std::string::npos)
           {
-            nvinfer1::ILayer *l = network->getLayer(idx);
-            if (strcmp(l->getName(), out_blob.c_str()) == 0)
-              {
-                outl = l;
-                break;
-              }
-            idx--;
+            le = read_engine_from_onnx();
+          }
+        else
+          {
+            throw MLLibInternalException(
+                "No model to parse for conversion to TensorRT");
           }
-        // force output to be float32
-        outl->setPrecision(nvinfer1::DataType::kFLOAT);
 
-        nvinfer1::ICudaEngine *le
-            = _builder->buildEngineWithConfig(*network, *_builderc);
         _engine = std::shared_ptr<nvinfer1::ICudaEngine>(
             le, [=](nvinfer1::ICudaEngine *e) { e->destroy(); });
 
@@ -407,9 +490,6 @@ namespace dd
                                    trtModelStream->size());
             trtModelStream->destroy();
           }
-
-        network->destroy();
-        caffeParser->destroy();
       }
 
     _context = std::shared_ptr<nvinfer1::IExecutionContext>(
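
Putting the pieces of this file together: engine construction now dispatches on the model file that read_from_repository discovered. A condensed sketch of the decision (illustrative free function; the real code branches inline and assigns the engine into le):

#include <stdexcept>
#include <string>

enum class ModelKind { Caffe, Onnx };

// Mirror of the dispatch above: a caffe prototxt (or a non-empty _def)
// selects the caffe parser, an ONNX file selects NvOnnxParser, and
// anything else is an error.
ModelKind pick_parser(const std::string &model, const std::string &def)
{
  if (model.find("net_tensorRT.proto") != std::string::npos || !def.empty())
    return ModelKind::Caffe;
  if (model.find("net_tensorRT.onnx") != std::string::npos)
    return ModelKind::Onnx;
  throw std::runtime_error("No model to parse for conversion to TensorRT");
}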

src/backends/tensorrt/tensorrtlib.h (+4)

@@ -143,6 +143,10 @@ namespace dd
     std::mutex
         _net_mutex; /**< mutex around net, e.g. no concurrent predict calls as
                        net is not re-instantiated. Use batches instead. */
+
+    nvinfer1::ICudaEngine *read_engine_from_caffe(const std::string &out_blob);
+
+    nvinfer1::ICudaEngine *read_engine_from_onnx();
   };
 
 }

src/backends/tensorrt/tensorrtmodel.cc (+22 -3)

@@ -28,16 +28,21 @@ namespace dd
     static std::string weights = ".caffemodel";
     static std::string corresp = "corresp";
     static std::string meanf = "mean.binaryproto";
+
+    static std::string model_name = "net_tensorRT";
+    static std::string caffe_model_name = model_name + ".proto";
+    static std::string onnx_model_name = model_name + ".onnx";
+
     std::unordered_set<std::string> lfiles;
     int e = fileops::list_directory(_repo, true, false, false, lfiles);
     if (e != 0)
       {
-        logger->error("error reading or listing caffe models in repository {}",
+        logger->error("error reading or listing models in repository {}",
                       _repo);
         return 1;
       }
-    std::string deployf, weightsf, correspf;
-    long int weight_t = -1;
+    std::string deployf, weightsf, correspf, modelf;
+    long int weight_t = -1, model_t = -1;
     auto hit = lfiles.begin();
     while (hit != lfiles.end())
       {
@@ -57,6 +62,16 @@ namespace dd
           }
         else if ((*hit).find(corresp) != std::string::npos)
           correspf = (*hit);
+        else if ((*hit).find(caffe_model_name) != std::string::npos
+                 || (*hit).find(onnx_model_name) != std::string::npos)
+          {
+            long int wt = fileops::file_last_modif(*hit);
+            if (wt > model_t)
+              {
+                modelf = (*hit);
+                model_t = wt;
+              }
+          }
         else if ((*hit).find("~") != std::string::npos
                  || (*hit).find(".prototxt") == std::string::npos)
           {
@@ -67,12 +82,16 @@ namespace dd
           deployf = (*hit);
         ++hit;
       }
+
     if (_def.empty())
       _def = deployf;
     if (_weights.empty())
       _weights = weightsf;
     if (_corresp.empty())
       _corresp = correspf;
+    if (_model.empty())
+      _model = modelf;
+
     return 0;
   }
 }
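
If a repository contains both net_tensorRT.proto and net_tensorRT.onnx, the new branch above keeps whichever was modified most recently. A sketch of that rule; the stat-based stub is an assumption about what fileops::file_last_modif returns:

#include <string>
#include <sys/stat.h>

// Assumed behaviour of fileops::file_last_modif: mtime via stat, -1 on error.
static long int last_modif(const std::string &path)
{
  struct stat st;
  return stat(path.c_str(), &st) == 0 ? static_cast<long int>(st.st_mtime) : -1;
}

// Keep the most recently modified candidate, as the new branch does for
// the net_tensorRT.proto vs net_tensorRT.onnx choice.
void keep_newest(const std::string &candidate, std::string &modelf,
                 long int &model_t)
{
  const long int wt = last_modif(candidate);
  if (wt > model_t)
    {
      modelf = candidate;
      model_t = wt;
    }
}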

src/backends/tensorrt/tensorrtmodel.h (+1)

@@ -53,6 +53,7 @@ namespace dd
 
     int read_from_repository(const std::shared_ptr<spdlog::logger> &logger);
 
+    std::string _model;
     std::string _def;
     std::string _weights;
     bool _has_mean_file = false;

tests/CMakeLists.txt (+8 -1)

@@ -267,12 +267,19 @@ if (GTEST_FOUND)
     "squeezenet_ssd_trt"
     )
   DOWNLOAD_DATASET(
-    "Downloading age test set"
+    "Age test set"
     "https://deepdetect.com/models/init/desktop/images/classification/age_real.tar.gz"
     "examples/trt/age_real"
     "age_real.tar.gz"
     "deploy.prototxt"
     )
+  DOWNLOAD_DATASET(
+    "ONNX resnet model"
+    "https://deepdetect.com/models/init/desktop/images/classification/resnet_onnx_trt.tar.gz"
+    "examples/trt"
+    "resnet_onnx_trt.tar.gz"
+    "resnet_onnx_trt"
+    )
 
 if(USE_JSON_API)
   REGISTER_TEST(ut_tensorrtapi ut-tensorrtapi.cc)
