#include "tensorrtinputconns.h"
#include "utils/apitools.h"
#include "NvInferPlugin.h"
+ #include "NvOnnxParser.h"
#include "protoUtils.h"
#include <cuda_runtime_api.h>
#include <string>
@@ -39,7 +40,12 @@ namespace dd
    fileops::list_directory(repo, true, false, false, lfiles);
    for (std::string s : lfiles)
      {
-       if (s.find(engineFileName) != std::string::npos)
+       // Omitting the directory name
+       auto fstart = s.find_last_of("/");
+       if (fstart == std::string::npos)
+         fstart = 0;
+
+       if (s.find(engineFileName, fstart) != std::string::npos)
          {
            std::string bs_str;
            for (auto it = s.crbegin(); it != s.crend(); ++it)
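// Illustration (hypothetical paths, not part of the change): with a repo
// such as "/opt/models/trt_squeezenet" and a cached engine file named
// "TRTengine_bs4", fstart points just past the last '/', so the engine
// name is matched against the file name only and can no longer collide
// with the directory path; the reversed scan above then reads the
// trailing digits ("4") back as the cached engine's batch size.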
@@ -134,6 +140,10 @@ namespace dd
        _max_batch_size = nmbs;
        this->_logger->info("setting max batch size to {}", _max_batch_size);
      }
+     if (ad.has("nclasses"))
+       {
+         _nclasses = ad.get("nclasses").get<int>();
+       }

      if (ad.has("dla"))
        _dla = ad.get("dla").get<int>();
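// Presumably added because an ONNX graph carries no caffe prototxt from
// which findNClasses() (used further down) could infer the class count.
// A minimal, hypothetical caller sketch, assuming the usual APIData
// service-creation path (names below are illustrative):
//
//   APIData ad_mllib;
//   ad_mllib.add("nclasses", 1000); // e.g. a 1000-class classifier
//   init_mllib(ad_mllib);           // picked up by the check above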
@@ -244,6 +254,114 @@ namespace dd
      return 0;
    }

+   template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
+             class TMLModel>
+   nvinfer1::ICudaEngine *
+   TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
+               TMLModel>::read_engine_from_caffe(const std::string &out_blob)
+   {
+     int fixcode = fixProto(this->_mlmodel._repo + "/" + "net_tensorRT.proto",
+                            this->_mlmodel._def);
+     switch (fixcode)
+       {
+       case 1:
+         this->_logger->error("TRT backend could not open model prototxt");
+         break;
+       case 2:
+         this->_logger->error("TRT backend could not write "
+                              "transformed model prototxt");
+         break;
+       default:
+         break;
+       }
+
+     nvinfer1::INetworkDefinition *network = _builder->createNetworkV2(0U);
+     nvcaffeparser1::ICaffeParser *caffeParser
+         = nvcaffeparser1::createCaffeParser();
+
+     const nvcaffeparser1::IBlobNameToTensor *blobNameToTensor
+         = caffeParser->parse(
+             std::string(this->_mlmodel._repo + "/" + "net_tensorRT.proto")
+                 .c_str(),
+             this->_mlmodel._weights.c_str(), *network, _datatype);
+     if (!blobNameToTensor)
+       throw MLLibInternalException("Error while parsing caffe model "
+                                    "for conversion to TensorRT");
+
+     network->markOutput(*blobNameToTensor->find(out_blob.c_str()));
+
+     if (out_blob == "detection_out")
+       network->markOutput(*blobNameToTensor->find("keep_count"));
+     _builder->setMaxBatchSize(_max_batch_size);
+     _builderc->setMaxWorkspaceSize(_max_workspace_size);
+
+     network->getLayer(0)->setPrecision(nvinfer1::DataType::kFLOAT);
+
+     nvinfer1::ILayer *outl = NULL;
+     int idx = network->getNbLayers() - 1;
+     while (outl == NULL)
+       {
+         nvinfer1::ILayer *l = network->getLayer(idx);
+         if (strcmp(l->getName(), out_blob.c_str()) == 0)
+           {
+             outl = l;
+             break;
+           }
+         idx--;
+       }
+     // force output to be float32
+     outl->setPrecision(nvinfer1::DataType::kFLOAT);
+     nvinfer1::ICudaEngine *engine
+         = _builder->buildEngineWithConfig(*network, *_builderc);
+
+     network->destroy();
+     if (caffeParser != nullptr)
+       caffeParser->destroy();
+
+     return engine;
+   }
+
+   template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
+             class TMLModel>
+   nvinfer1::ICudaEngine *
+   TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
+               TMLModel>::read_engine_from_onnx()
+   {
+     const auto explicitBatch
+         = 1U << static_cast<uint32_t>(
+               nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+
+     nvinfer1::INetworkDefinition *network
+         = _builder->createNetworkV2(explicitBatch);
+
+     nvonnxparser::IParser *onnxParser
+         = nvonnxparser::createParser(*network, trtLogger);
+     onnxParser->parseFromFile(this->_mlmodel._model.c_str(),
+                               int(nvinfer1::ILogger::Severity::kWARNING));
+
+     if (onnxParser->getNbErrors() != 0)
+       {
+         for (int i = 0; i < onnxParser->getNbErrors(); ++i)
+           {
+             this->_logger->error(onnxParser->getError(i)->desc());
+           }
+         throw MLLibInternalException(
+             "Error while parsing onnx model for conversion to "
+             "TensorRT");
+       }
+     _builder->setMaxBatchSize(_max_batch_size);
+     _builderc->setMaxWorkspaceSize(_max_workspace_size);
+
+     nvinfer1::ICudaEngine *engine
+         = _builder->buildEngineWithConfig(*network, *_builderc);
+
+     network->destroy();
+     if (onnxParser != nullptr)
+       onnxParser->destroy();
+
+     return engine;
+   }
+
    template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
              class TMLModel>
    int TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
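// Note on read_engine_from_onnx() above: TensorRT's ONNX parser only
// accepts networks created with the explicit-batch flag, hence
// kEXPLICIT_BATCH. A condensed sketch of the same build sequence, under
// the TensorRT 7-era API used in this diff (onnx_path stands in for
// this->_mlmodel._model):
//
//   auto *network = _builder->createNetworkV2(
//       1U << static_cast<uint32_t>(
//           nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
//   auto *parser = nvonnxparser::createParser(*network, trtLogger);
//   if (!parser->parseFromFile(onnx_path.c_str(),
//                              int(nvinfer1::ILogger::Severity::kWARNING)))
//     { /* log parser->getError(i)->desc() for each error and bail out */ }
//   auto *engine = _builder->buildEngineWithConfig(*network, *_builderc);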
@@ -293,7 +411,12 @@ namespace dd
              "timeseries not yet implemented over tensorRT backend");
        }

-     _nclasses = findNClasses(this->_mlmodel._def, _bbox);
+     if (_nclasses == 0)
+       {
+         this->_logger->info("try to determine number of classes...");
+         _nclasses = findNClasses(this->_mlmodel._def, _bbox);
+       }
+
      if (_bbox)
        _top_k = findTopK(this->_mlmodel._def);

@@ -335,65 +458,25 @@ namespace dd

        if (!engineRead)
          {
+           nvinfer1::ICudaEngine *le = nullptr;

-           int fixcode
-               = fixProto(this->_mlmodel._repo + "/" + "net_tensorRT.proto",
-                          this->_mlmodel._def);
-           switch (fixcode)
+           if (this->_mlmodel._model.find("net_tensorRT.proto")
+                   != std::string::npos
+               || !this->_mlmodel._def.empty())
              {
-             case 1:
-               this->_logger->error(
-                   "TRT backend could not open model prototxt");
-               break;
-             case 2:
-               this->_logger->error("TRT backend could not write "
-                                    "transformed model prototxt");
-               break;
-             default:
-               break;
+               le = read_engine_from_caffe(out_blob);
              }
-
-           nvinfer1::INetworkDefinition *network
-               = _builder->createNetworkV2(0U);
-           nvcaffeparser1::ICaffeParser *caffeParser
-               = nvcaffeparser1::createCaffeParser();
-
-           const nvcaffeparser1::IBlobNameToTensor *blobNameToTensor
-               = caffeParser->parse(std::string(this->_mlmodel._repo + "/"
                                                 + "net_tensorRT.proto")
-                                        .c_str(),
-                                    this->_mlmodel._weights.c_str(), *network,
-                                    _datatype);
-           if (!blobNameToTensor)
-             throw MLLibInternalException("Error while parsing caffe model "
-                                          "for conversion to TensorRT");
-
-           network->markOutput(*blobNameToTensor->find(out_blob.c_str()));
-
-           if (out_blob == "detection_out")
-             network->markOutput(*blobNameToTensor->find("keep_count"));
-           _builder->setMaxBatchSize(_max_batch_size);
-           _builderc->setMaxWorkspaceSize(_max_workspace_size);
-
-           network->getLayer(0)->setPrecision(nvinfer1::DataType::kFLOAT);
-
-           nvinfer1::ILayer *outl = NULL;
-           int idx = network->getNbLayers() - 1;
-           while (outl == NULL)
+           else if (this->_mlmodel._model.find("net_tensorRT.onnx")
+                    != std::string::npos)
              {
-               nvinfer1::ILayer *l = network->getLayer(idx);
-               if (strcmp(l->getName(), out_blob.c_str()) == 0)
-                 {
-                   outl = l;
-                   break;
-                 }
-               idx--;
+               le = read_engine_from_onnx();
+             }
+           else
+             {
+               throw MLLibInternalException(
+                   "No model to parse for conversion to TensorRT");
              }
-           // force output to be float32
-           outl->setPrecision(nvinfer1::DataType::kFLOAT);

-           nvinfer1::ICudaEngine *le
-               = _builder->buildEngineWithConfig(*network, *_builderc);
            _engine = std::shared_ptr<nvinfer1::ICudaEngine>(
                le, [=](nvinfer1::ICudaEngine *e) { e->destroy(); });
@@ -407,9 +490,6 @@ namespace dd
                                  trtModelStream->size());
            trtModelStream->destroy();
          }
-
-         network->destroy();
-         caffeParser->destroy();
        }

      _context = std::shared_ptr<nvinfer1::IExecutionContext>(
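// The network->destroy()/caffeParser->destroy() calls removed above are
// not lost: both read_engine_from_caffe() and read_engine_from_onnx()
// release their own network and parser before returning the built engine,
// which is then wrapped in the shared_ptr with a destroy() deleter as
// before.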