@@ -248,6 +248,7 @@ namespace dd
248
248
ad.add (" height" , this ->_inputc .height ());
249
249
}
250
250
}
251
+ this ->_stats .to (ad);
251
252
return ad;
252
253
}
253
254
@@ -283,7 +284,9 @@ namespace dd
283
284
static_cast <long int >(this ->_mem_used_train * sizeof (float )));
284
285
stats.add (" data_mem_test" ,
285
286
static_cast <long int >(this ->_mem_used_test * sizeof (float )));
286
- ad.add (" stats" , stats);
287
+ ad.add (" stats" , stats); // FIXME(sileht): deprecated name, delete me when
288
+ // platform use the new name
289
+ ad.add (" model_stats" , stats);
287
290
ad.add (" jobs" , vad);
288
291
ad.add (" parameters" , _init_parameters);
289
292
ad.add (" repository" , this ->_inputc ._model_repo );
@@ -292,6 +295,7 @@ namespace dd
292
295
ad.add (" type" , std::string (" unsupervised" ));
293
296
else
294
297
ad.add (" type" , std::string (" supervised" ));
298
+ this ->_stats .to (ad);
295
299
return ad;
296
300
}
297
301
@@ -495,36 +499,29 @@ namespace dd
495
499
*/
496
500
int predict_job (const APIData &ad, APIData &out, const bool &chain = false )
497
501
{
498
- // TODO: collect input transformed data for chain, store it here in
499
- // memory
500
- // -> beware, the input connector is a copy...
502
+ if (!_train_mutex. try_lock_shared ())
503
+ throw MLServiceLockException (
504
+ " Predict call while training with an offline learning algorithm " );
501
505
502
- if (!this ->_online )
506
+ this ->_stats .predict_start ();
507
+
508
+ int err = 0 ;
509
+ try
503
510
{
504
- if (!_train_mutex.try_lock_shared ())
505
- throw MLServiceLockException (" Predict call while training with an "
506
- " offline learning algorithm" );
507
- int err = 0 ;
508
- try
509
- {
510
- if (chain)
511
- const_cast <APIData &>(ad).add (" chain" , true );
512
- err = this ->predict (ad, out);
513
- }
514
- catch (std::exception &e)
515
- {
516
- _train_mutex.unlock_shared ();
517
- throw ;
518
- }
519
- _train_mutex.unlock_shared ();
520
- return err;
511
+ if (chain)
512
+ const_cast <APIData &>(ad).add (" chain" , true );
513
+ err = this ->predict (ad, out);
521
514
}
522
- else // wait til a lock can be acquired
515
+ catch (std:: exception &e)
523
516
{
524
- boost::shared_lock<boost::shared_mutex> lock (_train_mutex);
525
- return this ->predict (ad, out);
517
+ _train_mutex.unlock_shared ();
518
+ this ->_stats .predict_end (false );
519
+ throw ;
526
520
}
527
- return 0 ;
521
+ this ->_stats .predict_end (true );
522
+
523
+ _train_mutex.unlock_shared ();
524
+ return err;
528
525
}
529
526
530
527
std::string _sname; /* *< service name. */
0 commit comments