@@ -163,29 +163,119 @@ namespace dd
    _classif->to(device, dtype);
  }

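+  // build the torch graph from a caffe-style prototxt model definition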
+  void TorchModule::proto_model_load(const TorchModel &model)
+  {
+    _logger->info("loading " + model._proto);
+    try
+      {
+        _graph = std::make_shared<CaffeToTorch>(model._proto);
+      }
+    catch (std::exception &e)
+      {
+        _logger->error("unable to load " + model._proto);
+        throw;
+      }
+  }
+
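+  // reload graph weights from the traced file when the graph has been reallocated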
+  void TorchModule::graph_model_load(const TorchModel &tmodel)
+  {
+    if (!tmodel._traced.empty() && _graph->needs_reload())
+      {
+        _logger->info("loading " + tmodel._traced);
+        try
+          {
+            torch::load(_graph, tmodel._traced, _device);
+          }
+        catch (std::exception &e)
+          {
+            _logger->error("unable to load " + tmodel._traced);
+            throw;
+          }
+      }
+  }
+
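+  // load native module weights, if a native model file is present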
+  void TorchModule::native_model_load(const TorchModel &tmodel)
+  {
+    if (!tmodel._native.empty())
+      {
+        _logger->info("loading " + tmodel._native);
+        try
+          {
+            torch::load(_native, tmodel._native);
+          }
+        catch (std::exception &e)
+          {
+            _logger->error("unable to load " + tmodel._native);
+            throw;
+          }
+      }
+  }
+
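+  // load the classification layer weights from the weights file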
+  void TorchModule::classif_model_load(const TorchModel &model)
+  {
+    _logger->info("loading " + model._weights);
+    try
+      {
+        torch::load(_classif, model._weights, _device);
+      }
+    catch (std::exception &e)
+      {
+        _logger->error("unable to load " + model._weights);
+        throw;
+      }
+  }
+
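+  // load the classification layer from its dedicated file, if configured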
+  void TorchModule::classif_layer_load()
+  {
+    if (!_classif_layer_file.empty())
+      {
+        _logger->info("loading " + _classif_layer_file);
+        torch::load(_classif, _classif_layer_file, _device);
+      }
+  }
+
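+  // load a TorchScript traced model from disk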
+  void TorchModule::traced_model_load(TorchModel &model)
+  {
+    _logger->info("loading " + model._traced);
+    try
+      {
+        _traced = std::make_shared<torch::jit::script::Module>(
+            torch::jit::load(model._traced, _device));
+      }
+    catch (std::exception &e)
+      {
+        _logger->error("unable to load " + model._traced);
+        throw;
+      }
+  }
+
  template <class TInputConnectorStrategy>
  void TorchModule::post_transform(const std::string tmpl,
                                   const APIData &template_params,
                                   const TInputConnectorStrategy &inputc,
                                   const TorchModel &tmodel,
                                   const torch::Device &device)
  {
+    _device = device;
    this->_native = std::shared_ptr<NativeModule>(
        NativeFactory::from_template<TInputConnectorStrategy>(
            tmpl, template_params, inputc));

    if (_native)
-      if (!tmodel._native.empty())
-        torch::load(_native, tmodel._native, device);
+      {
+        _logger->info("created net using template " + tmpl);
+        native_model_load(tmodel);
+      }

    if (_graph)
      {
        std::vector<long int> dims = inputc._dataset.datasize(0);
        dims.insert(dims.begin(), 1); // dummy batch size
        _graph->finalize(dims);
+        if (_graph->needs_reload())
+          _logger->info("net was reallocated due to input dim changes");
        // reload params after finalize
-        if (!tmodel._traced.empty())
-          torch::load(_graph, tmodel._traced, _device);
+        graph_model_load(tmodel);
      }
    to(_device);
  }
@@ -361,11 +451,7 @@ namespace dd
    // First dimension is batch id
    int outdim = to_tensor_safe(forward(input_example)).sizes()[1];
    _classif = torch::nn::Linear(outdim, nclasses);
-
-    if (!_classif_layer_file.empty())
-      {
-        torch::load(_classif, _classif_layer_file, _device);
-      }
+    classif_layer_load();
  }

  std::vector<Tensor> TorchModule::parameters()
@@ -401,13 +487,13 @@ namespace dd
  void TorchModule::load(TorchModel &model)
  {
    if (!model._traced.empty() && model._proto.empty())
-      _traced = std::make_shared<torch::jit::script::Module>(
-          torch::jit::load(model._traced, _device));
+      traced_model_load(model);
+
    if (!model._weights.empty())
      {
        if (_classif)
          {
-            torch::load(_classif, model._weights, _device);
+            classif_model_load(model);
          }
        else if (_require_classif_layer)
          {
@@ -416,16 +502,12 @@ namespace dd
      }
    if (!model._proto.empty())
      {
-        _graph = std::make_shared<CaffeToTorch>(model._proto);
-        if (!model._traced.empty())
-          torch::load(_graph, model._traced, _device);
+        proto_model_load(model);
+        graph_model_load(model);
      }
+
    if (!model._native.empty())
-      {
-        std::shared_ptr<NativeModule> m;
-        torch::load(m, model._native);
-        _native = m;
-      }
+      native_model_load(model);
  }

  void TorchModule::eval()
@@ -544,6 +626,33 @@ namespace dd
      }
  }

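+  // reload solver state from the saved checkpoint and return the iteration to resume from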
+  template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
+            class TMLModel>
+  int
+  TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
+           TMLModel>::solver_load(std::unique_ptr<optim::Optimizer> &optimizer)
+  {
+    int it = 0;
+    if (!this->_mlmodel._sstate.empty())
+      {
+        this->_logger->info("Reload solver from {}", this->_mlmodel._sstate);
+        size_t start = this->_mlmodel._sstate.rfind("-") + 1;
+        size_t end = this->_mlmodel._sstate.rfind(".");
+        it = std::stoi(this->_mlmodel._sstate.substr(start, end - start));
+        this->_logger->info("Restarting optimization from iter {}", it);
+        this->_logger->info("loading " + this->_mlmodel._sstate);
+        try
+          {
+            torch::load(*optimizer, this->_mlmodel._sstate);
+          }
+        catch (std::exception &e)
+          {
+            this->_logger->error("unable to load " + this->_mlmodel._sstate);
+            throw;
+          }
+      }
+    return it;
+  }
+
  /* - from mllib -*/
  template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
            class TMLModel>
@@ -581,6 +690,7 @@ namespace dd
    _device = gpu ? torch::Device(DeviceType::CUDA, gpuid)
                  : torch::Device(DeviceType::CPU);
    _module._device = _device;
+    _module._logger = this->_logger;

    if (_template.find("recurrent") != std::string::npos)
      {
@@ -665,15 +775,6 @@ namespace dd
      }

    // Load weights
-    if (!this->_mlmodel._traced.empty())
-      this->_logger->info("Loading ml model from file {}.",
-                          this->_mlmodel._traced);
-    if (!this->_mlmodel._proto.empty())
-      this->_logger->info("Loading ml model from file {}.",
-                          this->_mlmodel._proto);
-    if (!this->_mlmodel._weights.empty())
-      this->_logger->info("Loading weights from file {}.",
-                          this->_mlmodel._weights);
    _module.load(this->_mlmodel);
    _module.freeze_traced(freeze_traced);

@@ -919,15 +1020,7 @@ namespace dd

    int it = 0;
    // reload solver and set its value accordingly
-    if (!this->_mlmodel._sstate.empty())
-      {
-        this->_logger->info("Reload solver from {}", this->_mlmodel._sstate);
-        size_t start = this->_mlmodel._sstate.rfind("-") + 1;
-        size_t end = this->_mlmodel._sstate.rfind(".");
-        it = std::stoi(this->_mlmodel._sstate.substr(start, end - start));
-        this->_logger->info("Restarting optimization from iter {}", it);
-        torch::load(*optimizer, this->_mlmodel._sstate);
-      }
+    it = solver_load(optimizer);
    optimizer->zero_grad();
    _module.train();

@@ -1422,7 +1515,6 @@ namespace dd
        unsupo.finalize(ad.getobj("parameters").getobj("output"), out,
                        static_cast<MLModel *>(&this->_mlmodel));
      }
-
    out.add("status", 0);
    return 0;
  }