Commit 0052a03

fantessileht authored and committed
fix(torch): load weights only once
1 parent cc086ff · commit 0052a03

6 files changed: +167 −47 lines

src/backends/torch/torchgraphbackend.cc (+4)
@@ -189,6 +189,7 @@ namespace dd
 
   void TorchGraphBackend::allocate_modules()
   {
+    _allocation_done = false;
     for (BaseGraph::Vertex v : _sortedOps)
       {
         if (!_graph[v].alloc_needed)
@@ -220,6 +221,7 @@ namespace dd
             _modules[opname] = AnyModule(m);
             _graph[v].alloc_needed = false;
             _rnn_has_memories[opname] = false;
+            _allocation_done = true;
           }
         else if (optype == "RNN")
           {
@@ -233,6 +235,7 @@ namespace dd
             _modules[opname] = AnyModule(m);
             _graph[v].alloc_needed = false;
             _rnn_has_memories[opname] = false;
+            _allocation_done = true;
           }
         else if (optype == "InnerProduct")
           {
@@ -243,6 +246,7 @@ namespace dd
                 Linear(LinearOptions(dim(v, 0, 2), num_output(v)).bias(true)));
             _modules[opname] = AnyModule(m);
             _graph[v].alloc_needed = false;
+            _allocation_done = true;
           }
         else if (optype == "Tile")
           _graph[v].alloc_needed = false;
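
The mechanism above is a dirty flag: allocate_modules() clears _allocation_done on entry and sets it whenever an LSTM, RNN or InnerProduct module is actually (re)built, so callers can ask afterwards whether the current weights were invalidated. A minimal standalone sketch of the pattern (the Net class and its members are illustrative stand-ins, not DeepDetect types):

#include <iostream>

// Sketch of the allocation-tracking flag introduced above.
class Net
{
public:
  void allocate_modules(bool rebuild_needed)
  {
    _allocation_done = false; // reset on every call
    if (rebuild_needed)
      {
        // ... allocate the torch module here ...
        _allocation_done = true; // fresh weights, caller must reload
      }
  }

  // true if the last allocate_modules() call rebuilt something
  bool needs_reload() const
  {
    return _allocation_done;
  }

private:
  bool _allocation_done = false;
};

int main()
{
  Net net;
  net.allocate_modules(false);
  std::cout << net.needs_reload() << std::endl; // 0: keep current weights
  net.allocate_modules(true);
  std::cout << net.needs_reload() << std::endl; // 1: reload weights once
}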

src/backends/torch/torchgraphbackend.h (+11 −2)
@@ -177,10 +177,18 @@ namespace dd
       _parameters_used = false;
     }
 
+    /**
+     * tells if some allocation was done (needs to be called just after
+     * set_inputdim or finalize)
+     */
+    bool needs_reload()
+    {
+      return _allocation_done;
+    }
+
   protected:
     /**
      * internal torch module allocation, called within finalize()
-     * @param force
      */
    void allocate_modules();
 
@@ -215,8 +223,9 @@ namespace dd
     std::unordered_map<std::string, bool>
         _rnn_has_memories; /**< true if previous hidden values are available
                             */
-  };
 
+    bool _allocation_done = false;
+  };
 }
 
 #endif
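
On the caller side, needs_reload() is meant to be checked right after set_inputdim() or finalize(), as the new header comment says; the torchlib.cc hunks below do exactly that in post_transform(). A hedged, self-contained sketch of the call sequence (GraphStub and load_weights are stand-ins for TorchGraphBackend and torch::load):

#include <iostream>
#include <string>
#include <vector>

// Stand-in for the graph backend, just enough to show the call order.
struct GraphStub
{
  bool _reallocated = false;
  void finalize(const std::vector<long int> &dims)
  {
    _reallocated = !dims.empty(); // pretend new dims forced a rebuild
  }
  bool needs_reload() const
  {
    return _reallocated;
  }
};

// Stand-in for torch::load(_graph, file, device).
void load_weights(GraphStub &, const std::string &file)
{
  std::cout << "loading " << file << std::endl;
}

int main()
{
  GraphStub graph;
  std::vector<long int> dims = { 50, 3 };
  dims.insert(dims.begin(), 1); // dummy batch size, as in post_transform()
  graph.finalize(dims);
  if (graph.needs_reload()) // reload weights only after a real rebuild
    load_weights(graph, "model.pt");
}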

src/backends/torch/torchlib.cc (+131 −39)
@@ -163,29 +163,119 @@ namespace dd
     _classif->to(device, dtype);
   }
 
+  void TorchModule::proto_model_load(const TorchModel &model)
+  {
+    _logger->info("loading " + model._proto);
+    try
+      {
+        _graph = std::make_shared<CaffeToTorch>(model._proto);
+      }
+    catch (std::exception &e)
+      {
+        _logger->info("unable to load " + model._proto);
+        throw;
+      }
+  }
+
+  void TorchModule::graph_model_load(const TorchModel &tmodel)
+  {
+    if (!tmodel._traced.empty() && _graph->needs_reload())
+      {
+        _logger->info("loading " + tmodel._traced);
+        try
+          {
+            torch::load(_graph, tmodel._traced, _device);
+          }
+        catch (std::exception &e)
+          {
+            _logger->error("unable to load " + tmodel._traced);
+            throw;
+          }
+      }
+  }
+
+  void TorchModule::native_model_load(const TorchModel &tmodel)
+  {
+    if (!tmodel._native.empty())
+      {
+        _logger->info("loading " + tmodel._native);
+        try
+          {
+            torch::load(_native, tmodel._native);
+          }
+        catch (std::exception &e)
+          {
+            _logger->error("unable to load " + tmodel._native);
+            throw;
+          }
+      }
+  }
+
+  void TorchModule::classif_model_load(const TorchModel &model)
+  {
+    _logger->info("loading " + model._weights);
+    try
+      {
+        torch::load(_classif, model._weights, _device);
+      }
+    catch (std::exception &e)
+      {
+        _logger->error("unable to load " + model._weights);
+        throw;
+      }
+  }
+
+  void TorchModule::classif_layer_load()
+  {
+    if (!_classif_layer_file.empty())
+      {
+        _logger->info("loading " + _classif_layer_file);
+        torch::load(_classif, _classif_layer_file, _device);
+      }
+  }
+
+  void TorchModule::traced_model_load(TorchModel &model)
+  {
+    _logger->info("loading " + model._traced);
+    try
+      {
+        _traced = std::make_shared<torch::jit::script::Module>(
+            torch::jit::load(model._traced, _device));
+      }
+    catch (std::exception &e)
+      {
+        _logger->error("unable to load " + model._traced);
+        throw;
+      }
+  }
+
   template <class TInputConnectorStrategy>
   void TorchModule::post_transform(const std::string tmpl,
                                    const APIData &template_params,
                                    const TInputConnectorStrategy &inputc,
                                    const TorchModel &tmodel,
                                    const torch::Device &device)
   {
+    _device = device;
     this->_native = std::shared_ptr<NativeModule>(
         NativeFactory::from_template<TInputConnectorStrategy>(
             tmpl, template_params, inputc));
 
     if (_native)
-      if (!tmodel._native.empty())
-        torch::load(_native, tmodel._native, device);
+      {
+        _logger->info("created net using template " + tmpl);
+        native_model_load(tmodel);
+      }
 
     if (_graph)
       {
         std::vector<long int> dims = inputc._dataset.datasize(0);
         dims.insert(dims.begin(), 1); // dummy batch size
         _graph->finalize(dims);
+        if (_graph->needs_reload())
+          _logger->info("net was reallocated due to input dim changes");
         // reload params after finalize
-        if (!tmodel._traced.empty())
-          torch::load(_graph, tmodel._traced, _device);
+        graph_model_load(tmodel);
       }
     to(_device);
   }
@@ -361,11 +451,7 @@ namespace dd
     // First dimension is batch id
     int outdim = to_tensor_safe(forward(input_example)).sizes()[1];
     _classif = torch::nn::Linear(outdim, nclasses);
-
-    if (!_classif_layer_file.empty())
-      {
-        torch::load(_classif, _classif_layer_file, _device);
-      }
+    classif_layer_load();
   }
 
   std::vector<Tensor> TorchModule::parameters()
@@ -401,13 +487,13 @@ namespace dd
   void TorchModule::load(TorchModel &model)
   {
     if (!model._traced.empty() && model._proto.empty())
-      _traced = std::make_shared<torch::jit::script::Module>(
-          torch::jit::load(model._traced, _device));
+      traced_model_load(model);
+
     if (!model._weights.empty())
       {
         if (_classif)
           {
-            torch::load(_classif, model._weights, _device);
+            classif_model_load(model);
           }
         else if (_require_classif_layer)
           {
@@ -416,16 +502,12 @@ namespace dd
       }
     if (!model._proto.empty())
       {
-        _graph = std::make_shared<CaffeToTorch>(model._proto);
-        if (!model._traced.empty())
-          torch::load(_graph, model._traced, _device);
+        proto_model_load(model);
+        graph_model_load(model);
       }
+
     if (!model._native.empty())
-      {
-        std::shared_ptr<NativeModule> m;
-        torch::load(m, model._native);
-        _native = m;
-      }
+      native_model_load(model);
   }
 
   void TorchModule::eval()
@@ -544,6 +626,33 @@ namespace dd
       }
   }
 
+  template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
+            class TMLModel>
+  void
+  TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
+           TMLModel>::solver_load(std::unique_ptr<optim::Optimizer> &optimizer)
+  {
+    if (!this->_mlmodel._sstate.empty())
+      {
+
+        this->_logger->info("Reload solver from {}", this->_mlmodel._sstate);
+        size_t start = this->_mlmodel._sstate.rfind("-") + 1;
+        size_t end = this->_mlmodel._sstate.rfind(".");
+        int it = std::stoi(this->_mlmodel._sstate.substr(start, end - start));
+        this->_logger->info("Restarting optimization from iter {}", it);
+        this->_logger->info("loading " + this->_mlmodel._sstate);
+        try
+          {
+            torch::load(*optimizer, this->_mlmodel._sstate);
+          }
+        catch (std::exception &e)
+          {
+            this->_logger->error("unable to load " + this->_mlmodel._sstate);
+            throw;
+          }
+      }
+  }
+
   /*- from mllib -*/
   template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
             class TMLModel>
@@ -581,6 +690,7 @@ namespace dd
     _device = gpu ? torch::Device(DeviceType::CUDA, gpuid)
                   : torch::Device(DeviceType::CPU);
     _module._device = _device;
+    _module._logger = this->_logger;
 
     if (_template.find("recurrent") != std::string::npos)
       {
@@ -665,15 +775,6 @@ namespace dd
       }
 
     // Load weights
-    if (!this->_mlmodel._traced.empty())
-      this->_logger->info("Loading ml model from file {}.",
-                          this->_mlmodel._traced);
-    if (!this->_mlmodel._proto.empty())
-      this->_logger->info("Loading ml model from file {}.",
-                          this->_mlmodel._proto);
-    if (!this->_mlmodel._weights.empty())
-      this->_logger->info("Loading weights from file {}.",
-                          this->_mlmodel._weights);
     _module.load(this->_mlmodel);
     _module.freeze_traced(freeze_traced);
 
@@ -919,15 +1020,7 @@ namespace dd
 
     int it = 0;
     // reload solver and set it value accordingly
-    if (!this->_mlmodel._sstate.empty())
-      {
-        this->_logger->info("Reload solver from {}", this->_mlmodel._sstate);
-        size_t start = this->_mlmodel._sstate.rfind("-") + 1;
-        size_t end = this->_mlmodel._sstate.rfind(".");
-        it = std::stoi(this->_mlmodel._sstate.substr(start, end - start));
-        this->_logger->info("Restarting optimization from iter {}", it);
-        torch::load(*optimizer, this->_mlmodel._sstate);
-      }
+    solver_load(optimizer);
     optimizer->zero_grad();
     _module.train();
 
@@ -1422,7 +1515,6 @@ namespace dd
         unsupo.finalize(ad.getobj("parameters").getobj("output"), out,
                         static_cast<MLModel *>(&this->_mlmodel));
       }
-
     out.add("status", 0);
     return 0;
   }
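
All the new *_model_load() helpers share one shape: log the file being loaded, attempt the load, and on failure log an error and rethrow so the caller still sees the original exception instead of a silent failure. A generic distillation of that shape (checked_load and its std::function loader are illustrative, not part of DeepDetect):

#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>

// Log-try-rethrow shape shared by the *_model_load() helpers above.
void checked_load(const std::string &path,
                  const std::function<void(const std::string &)> &loader)
{
  std::cout << "loading " << path << std::endl;
  try
    {
      loader(path);
    }
  catch (const std::exception &)
    {
      std::cerr << "unable to load " << path << std::endl;
      throw; // let the caller decide how to abort
    }
}

int main()
{
  try
    {
      checked_load("weights.pt", [](const std::string &) {
        throw std::runtime_error("corrupt file"); // simulate a bad load
      });
    }
  catch (const std::exception &e)
    {
      std::cout << "caught: " << e.what() << std::endl;
    }
}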

src/backends/torch/torchlib.h (+13)
@@ -142,8 +142,16 @@ namespace dd
                       the file where the weights are stored */
     unsigned int _nclasses = 0;
 
+    std::shared_ptr<spdlog::logger> _logger; /**< mllib logger. */
+
   private:
     bool _freeze_traced = false; /**< Freeze weights of the traced module */
+    void proto_model_load(const TorchModel &tmodel);
+    void graph_model_load(const TorchModel &tmodel);
+    void native_model_load(const TorchModel &tmodel);
+    void classif_model_load(const TorchModel &tmodel);
+    void traced_model_load(TorchModel &model);
+    void classif_layer_load();
   };
 
   template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -203,6 +211,11 @@ namespace dd
 
     void snapshot(int64_t elapsed_it, torch::optim::Optimizer &optimizer);
 
+    /**
+     * \brief (re)load solver state
+     */
+    void solver_load(std::unique_ptr<torch::optim::Optimizer> &optimizer);
+
     void remove_model(int64_t it);
 
     double unscale(double val, unsigned int k,
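
solver_load() recovers the iteration count by slicing the snapshot filename between the last '-' and the last '.', which assumes names of the form <prefix>-<iter>.<ext> (e.g. solver-1000.pt). A standalone sketch of just that parsing step:

#include <iostream>
#include <string>

// Same rfind-based slicing as solver_load() above.
// Assumes snapshot names of the form <prefix>-<iter>.<ext>.
int iter_from_sstate(const std::string &sstate)
{
  size_t start = sstate.rfind("-") + 1;
  size_t end = sstate.rfind(".");
  return std::stoi(sstate.substr(start, end - start));
}

int main()
{
  std::cout << iter_from_sstate("models/solver-1000.pt") << std::endl; // 1000
}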

src/caffegraphinput.cc (+7 −5)
@@ -28,6 +28,8 @@
 #include <fcntl.h>
 #include <unistd.h>
 
+#include "mllibstrategy.h"
+
 using google::protobuf::io::CodedInputStream;
 using google::protobuf::io::CodedOutputStream;
 using google::protobuf::io::FileInputStream;
@@ -177,19 +179,19 @@ namespace dd
     return true;
   }
 
-  int CaffeGraphInput::from_proto(std::string filename)
+  void CaffeGraphInput::from_proto(std::string filename)
   {
     caffe::NetParameter net;
     if (!read_proto(filename, &net))
-      return -1;
+      throw MLLibBadParamException("unable to parse protofile");
 
     bool simple_lstm = is_simple_lstm(net);
     if (simple_lstm)
       {
         parse_simple_lstm(net);
-        return 0;
+        return;
       }
-    return 0;
+    throw MLLibBadParamException(
+        "proto file does not contain a proper LSTM/autoencoder");
   }
-
 }
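
The from_proto() change swaps an easily ignored int status code for exceptions: a parse failure or an unsupported topology now throws MLLibBadParamException, so no caller can proceed with a half-initialized graph. A self-contained sketch of the new calling convention (ProtoError stands in for MLLibBadParamException, and the two boolean flags replace the real parsing):

#include <iostream>
#include <stdexcept>
#include <string>

// Stand-in for MLLibBadParamException.
struct ProtoError : std::runtime_error
{
  using std::runtime_error::runtime_error;
};

// Mirrors the post-change control flow of from_proto().
void from_proto(const std::string &filename, bool parses, bool simple_lstm)
{
  if (!parses)
    throw ProtoError("unable to parse protofile: " + filename);
  if (simple_lstm)
    return; // recognized topology, graph was built
  throw ProtoError("proto file does not contain a proper LSTM/autoencoder");
}

int main()
{
  try
    {
      from_proto("net.prototxt", /*parses=*/true, /*simple_lstm=*/false);
    }
  catch (const ProtoError &e)
    {
      std::cout << "rejected: " << e.what() << std::endl;
    }
}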

src/caffegraphinput.h (+1 −1)
@@ -53,7 +53,7 @@ namespace dd
     /**
      * create basegraph from proto
      */
-    int from_proto(std::string filename);
+    void from_proto(std::string filename);
 
     /**
      * read protofile
