Skip to content

Commit 6e81915

Browse files
Bycobsileht
authored and committed
feat(dede): Training for image classification with torch
- Add db support
- Add test split for raw data without db
- Download archive with other tests
- Fix weight loading for classification head
- Add predict call in tests
- Add test to load db with train/test directories
1 parent ca58c51 commit 6e81915

7 files changed

+558
-79
lines changed

src/backends/torch/torchinputconns.cc

+294-7
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@
2222

2323
#include "torchinputconns.h"
2424

25+
#include "utils/utils.hpp"
26+
2527
namespace dd
2628
{
2729

@@ -30,7 +32,7 @@ namespace dd
3032
void TorchInputInterface::build_test_datadb_from_full_datadb(double tsplit)
3133
{
3234
_tilogger->info("splitting : using {} of dataset as test set", tsplit);
33-
_dataset.reset(db::WRITE);
35+
_dataset.reset(true, db::WRITE);
3436
std::vector<int64_t> indicestest;
3537
int64_t ntest = _dataset._indices.size() * tsplit;
3638
auto seed = static_cast<long>(time(NULL));
@@ -85,6 +87,21 @@ namespace dd
8587
return true;
8688
}
8789

90+
std::vector<c10::IValue>
91+
TorchInputInterface::get_input_example(torch::Device device)
92+
{
93+
_dataset.reset();
94+
auto batchopt = _dataset.get_batch({ 1 });
95+
TorchBatch batch = batchopt.value();
96+
std::vector<c10::IValue> input_example;
97+
98+
for (auto &t : batch.data)
99+
{
100+
input_example.push_back(t.to(device));
101+
}
102+
return input_example;
103+
}
104+
88105
// ===== TorchDataset
89106

90107
void TorchDataset::finalize_db()
@@ -155,9 +172,8 @@ namespace dd
155172
_indices.push_back(index);
156173
}
157174

158-
void TorchDataset::write_tensors_to_db(
159-
std::vector<at::Tensor> data,
160-
__attribute__((unused)) std::vector<at::Tensor> target)
175+
void TorchDataset::write_tensors_to_db(const std::vector<at::Tensor> &data,
176+
const std::vector<at::Tensor> &target)
161177
{
162178
std::ostringstream dstream;
163179
torch::save(data, dstream);
@@ -180,7 +196,7 @@ namespace dd
180196
_txn->Put(data_key.str(), dstream.str());
181197
_txn->Put(target_key.str(), tstream.str());
182198

183-
// should not commit transations every time;
199+
// should not commit transactions every time;
184200
if (++_current_index % _batches_per_transaction == 0)
185201
{
186202
_txn->Commit();
@@ -189,8 +205,8 @@ namespace dd
189205
}
190206
}
191207

192-
void TorchDataset::add_batch(std::vector<at::Tensor> data,
193-
std::vector<at::Tensor> target)
208+
void TorchDataset::add_batch(const std::vector<at::Tensor> &data,
209+
const std::vector<at::Tensor> &target)
194210
{
195211
if (!_db)
196212
_batches.push_back(TorchBatch(data, target));
@@ -369,6 +385,277 @@ namespace dd
369385
return new_dataset;
370386
}
371387

388+
// ===== ImgTorchInputFileConn
389+
390+
void ImgTorchInputFileConn::read_image_folder(
391+
std::vector<std::pair<std::string, int>> &lfiles,
392+
std::unordered_map<int, std::string> &hcorresp,
393+
std::unordered_map<std::string, int> &hcorresp_r,
394+
const std::string &folderPath)
395+
{
396+
397+
// TODO Put file parsing from caffe in common files to use it in other
398+
// backends
399+
int cl = 0;
400+
401+
std::unordered_set<std::string> subdirs;
402+
if (fileops::list_directory(folderPath, false, true, false, subdirs))
403+
throw InputConnectorBadParamException(
404+
"failed reading image train data directory " + folderPath);
405+
406+
auto uit = subdirs.begin();
407+
while (uit != subdirs.end())
408+
{
409+
std::unordered_set<std::string> subdir_files;
410+
if (fileops::list_directory((*uit), true, false, true, subdir_files))
411+
throw InputConnectorBadParamException(
412+
"failed reading image train data sub-directory " + (*uit));
413+
std::string cls = dd_utils::split((*uit), '/').back();
414+
hcorresp.insert(std::pair<int, std::string>(cl, cls));
415+
hcorresp_r.insert(std::pair<std::string, int>(cls, cl));
416+
auto fit = subdir_files.begin();
417+
while (
418+
fit
419+
!= subdir_files.end()) // XXX: re-iterating the file is not optimal
420+
{
421+
lfiles.push_back(std::pair<std::string, int>((*fit), cl));
422+
++fit;
423+
}
424+
++cl;
425+
++uit;
426+
}
427+
}
428+
429+
void ImgTorchInputFileConn::transform(const APIData &ad)
430+
{
431+
if (!_train)
432+
{
433+
try
434+
{
435+
ImgInputFileConn::transform(ad);
436+
}
437+
catch (InputConnectorBadParamException &e)
438+
{
439+
throw;
440+
}
441+
442+
// XXX: No predict from db yet
443+
_dataset.set_dbParams(false, "", "");
444+
445+
for (size_t i = 0; i < this->_images.size(); ++i)
446+
{
447+
_dataset.add_batch({ image_to_tensor(this->_images[i]) });
448+
}
449+
}
450+
else // if (!_train)
451+
{
452+
// This must be done since we don't call ImgInputFileConn::transform
453+
if (ad.has("parameters")) // overriding default parameters
454+
{
455+
APIData ad_param = ad.getobj("parameters");
456+
if (ad_param.has("input"))
457+
{
458+
fillup_parameters(ad_param.getobj("input"));
459+
}
460+
}
461+
462+
// Read all parsed files and create tensor datasets
463+
bool createDb
464+
= _db && TorchInputInterface::has_to_create_db(ad, _test_split);
465+
bool shouldLoad = !_db || createDb;
466+
467+
if (shouldLoad)
468+
{
469+
if (_db)
470+
_tilogger->info("Load from db");
471+
// Get files paths
472+
try
473+
{
474+
get_data(ad);
475+
}
476+
catch (InputConnectorBadParamException &e)
477+
{
478+
throw;
479+
}
480+
481+
bool dir_images = true;
482+
fileops::file_exists(_uris.at(0), dir_images);
483+
484+
// Parse URIs and retrieve images
485+
std::unordered_map<int, std::string>
486+
hcorresp; // correspondence class number / class name
487+
std::unordered_map<std::string, int>
488+
hcorresp_r; // reverse correspondence for test set.
489+
std::vector<std::pair<std::string, int>> lfiles; // labeled files
490+
std::vector<std::pair<std::string, int>> test_lfiles;
491+
492+
bool folder = dir_images;
493+
494+
if (folder)
495+
{
496+
read_image_folder(lfiles, hcorresp, hcorresp_r, _uris.at(0));
497+
// TODO manage test split
498+
499+
if (_uris.size() > 1)
500+
{
501+
std::unordered_map<int, std::string>
502+
test_hcorresp; // correspondence class number / class
503+
// name
504+
std::unordered_map<std::string, int>
505+
test_hcorresp_r; // reverse correspondence for test
506+
// set.
507+
508+
read_image_folder(test_lfiles, test_hcorresp,
509+
test_hcorresp_r, _uris.at(1));
510+
}
511+
}
512+
else
513+
{
514+
throw InputConnectorBadParamException(
515+
"Torch image input connector expects folders");
516+
}
517+
518+
bool has_test_data = test_lfiles.size() != 0;
519+
520+
if (_test_split > 0.0 && !has_test_data)
521+
{
522+
// TODO Code for shuffling based on seed / splitting should
523+
// should be put in a common place
524+
525+
// shuffle
526+
std::mt19937 g;
527+
if (_seed >= 0)
528+
g = std::mt19937(_seed);
529+
else
530+
{
531+
std::random_device rd;
532+
g = std::mt19937(rd());
533+
}
534+
std::shuffle(lfiles.begin(), lfiles.end(), g);
535+
536+
// Split
537+
int split_pos
538+
= std::floor(lfiles.size() * (1.0 - _test_split));
539+
540+
auto split_begin = lfiles.begin();
541+
std::advance(split_begin, split_pos);
542+
test_lfiles.insert(test_lfiles.begin(), split_begin,
543+
lfiles.end());
544+
lfiles.erase(split_begin, lfiles.end());
545+
546+
_logger->info(
547+
"data split test size={} / remaining data size={}",
548+
test_lfiles.size(), lfiles.size());
549+
}
550+
551+
// Read data
552+
for (const std::pair<std::string, int> &lfile : lfiles)
553+
{
554+
add_image_file(_dataset, lfile.first, lfile.second);
555+
}
556+
557+
for (const std::pair<std::string, int> &lfile : test_lfiles)
558+
{
559+
add_image_file(_test_dataset, lfile.first, lfile.second);
560+
}
561+
562+
// Write corresp file
563+
std::ofstream correspf(_model_repo + "/" + _correspname,
564+
std::ios::binary);
565+
auto hit = hcorresp.begin();
566+
while (hit != hcorresp.end())
567+
{
568+
correspf << (*hit).first << " " << (*hit).second << std::endl;
569+
++hit;
570+
}
571+
correspf.close();
572+
}
573+
574+
if (createDb)
575+
{
576+
_dataset.finalize_db();
577+
_test_dataset.finalize_db();
578+
}
579+
}
580+
}
581+
582+
/**
 * \brief reads one image file and adds it with its target to a dataset
 *
 * The image is read through a DDImg configured with the connector's image
 * parameters. A read failure (non zero return or exception) is logged; the
 * file is added only if at least one image was produced.
 *
 * @param dataset dataset receiving the (image, target) batch
 * @param fname path of the image file
 * @param target class index stored as a one element long tensor
 * @return 0 if the image was added, -1 if no image could be read
 */
int ImgTorchInputFileConn::add_image_file(TorchDataset &dataset,
                                          const std::string &fname,
                                          int target)
{
  // forward the connector's image settings to the reader
  DDImg reader;
  reader._width = _width;
  reader._height = _height;
  reader._bw = _bw;
  reader._rgb = _rgb;
  reader._histogram_equalization = _histogram_equalization;
  reader._unchanged_data = _unchanged_data;
  reader._crop_width = _crop_width;
  reader._crop_height = _crop_height;
  reader._scale = _scale;
  reader._scaled = _scaled;
  reader._scale_min = _scale_min;
  reader._scale_max = _scale_max;
  reader._keep_orig = _keep_orig;
  reader._interp = _interp;

  // a non zero return and an exception are reported the same way
  bool read_failed = false;
  try
    {
      if (reader.read_file(fname))
        read_failed = true;
    }
  catch (std::exception &e)
    {
      read_failed = true;
    }
  if (read_failed)
    this->_logger->error("Uri failed: {}", fname);

  // nothing decoded: tell the caller the file was not added
  if (reader._imgs.size() == 0)
    return -1;

  at::Tensor image_tensor = image_to_tensor(reader._imgs[0]);
  at::Tensor target_tensor{ torch::full(1, target, torch::kLong) };
  dataset.add_batch({ image_tensor }, { target_tensor });
  return 0;
}
627+
628+
/**
 * \brief converts an OpenCV image into a float tensor in channel-first order
 *
 * The pixel buffer is wrapped as a (rows, cols, channels) byte tensor,
 * converted to float and permuted to (channels, rows, cols). The tensor is
 * then multiplied by _scale when it differs from 1, and, when _mean / _std
 * are set, channel c has _mean[c] subtracted and is divided by _std[c].
 *
 * NOTE(review): the buffer is read as bytes, which assumes an 8-bit image
 * depth -- confirm against the image reader.
 *
 * @param bgr input image
 * @return float tensor of shape (channels, rows, cols)
 * @throws InputConnectorBadParamException if _mean or _std is set with a
 *         size different from the number of channels
 */
at::Tensor ImgTorchInputFileConn::image_to_tensor(const cv::Mat &bgr)
{
  // from_blob reads the buffer as one dense block: a non continuous matrix
  // (e.g. a region of a larger image) must be copied into dense storage
  // first
  const cv::Mat dense = bgr.isContinuous() ? bgr : bgr.clone();

  // use the actual image dimensions so that the blob never extends past the
  // pixel buffer
  std::vector<int64_t> sizes{ dense.rows, dense.cols, dense.channels() };
  at::TensorOptions options(at::ScalarType::Byte);

  at::Tensor imgt
      = torch::from_blob(dense.data, at::IntList(sizes), options);
  // the byte -> float conversion allocates new storage, so the result does
  // not alias the image buffer
  imgt = imgt.toType(at::kFloat).permute({ 2, 0, 1 });
  size_t nchannels = imgt.size(0);

  if (_scale != 1.0)
    imgt = imgt.mul(_scale);

  if (!_mean.empty() && _mean.size() != nchannels)
    throw InputConnectorBadParamException(
        "mean vector be of size the number of channels ("
        + std::to_string(nchannels) + ")");

  // the tensor is (channels, rows, cols): imgt[c] is the whole channel c,
  // updated in place
  for (size_t m = 0; m < _mean.size(); m++)
    imgt[static_cast<int64_t>(m)].sub_(_mean.at(m));

  if (!_std.empty() && _std.size() != nchannels)
    throw InputConnectorBadParamException(
        "std vector be of size the number of channels ("
        + std::to_string(nchannels) + ")");

  for (size_t s = 0; s < _std.size(); s++)
    imgt[static_cast<int64_t>(s)].div_(_std.at(s));

  return imgt;
}
658+
372659
// ===== TxtTorchInputFileConn
373660

374661
void TxtTorchInputFileConn::parse_content(const std::string &content,

0 commit comments

Comments
 (0)