Skip to content

Commit aa30e88

Browse files
fantessileht
authored andcommitted
fix(torch/timeseries): unscale prediction output if needed
1 parent 19e9674 commit aa30e88

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

src/backends/torch/torchlib.cc

+42-1
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,46 @@ namespace dd
502502
empty_cuda_cache();
503503
}
504504

505+
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
506+
class TMLModel>
507+
double TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
508+
TMLModel>::unscale(double val, unsigned int k,
509+
const TInputConnectorStrategy &inputc)
510+
{
511+
(void)inputc;
512+
(void)k;
513+
// unscaling is input connector specific
514+
return val;
515+
}
516+
517+
// full template specialization
518+
template <>
519+
double
520+
TorchLib<CSVTSTorchInputFileConn, SupervisedOutput, TorchModel>::unscale(
521+
double val, unsigned int k, const CSVTSTorchInputFileConn &inputc)
522+
523+
{
524+
if (inputc._min_vals.empty() || inputc._max_vals.empty())
525+
{
526+
this->_logger->info("not unscaling output because no bounds "
527+
"data found");
528+
return val;
529+
}
530+
else
531+
{
532+
533+
if (!inputc._dont_scale_labels)
534+
{
535+
double max = inputc._max_vals[inputc._label_pos[k]];
536+
double min = inputc._min_vals[inputc._label_pos[k]];
537+
if (inputc._scale_between_minus1_and_1)
538+
val += 0.5;
539+
val = val * (max - min) + min;
540+
}
541+
return val;
542+
}
543+
}
544+
505545
/*- from mllib -*/
506546
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
507547
class TMLModel>
@@ -1342,7 +1382,8 @@ namespace dd
13421382
for (unsigned int k = 0; k < this->_inputc._ntargets;
13431383
++k)
13441384
{
1345-
preds.push_back(output_acc[j][t][k]);
1385+
double res = output_acc[j][t][k];
1386+
preds.push_back(unscale(res, k, inputc));
13461387
}
13471388
APIData ts;
13481389
ts.add("out", preds);

src/backends/torch/torchlib.h

+3
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ namespace dd
204204
void snapshot(int64_t elapsed_it, torch::optim::Optimizer &optimizer);
205205

206206
void remove_model(int64_t it);
207+
208+
double unscale(double val, unsigned int k,
209+
const TInputConnectorStrategy &inputc);
207210
};
208211
}
209212

0 commit comments

Comments
 (0)