22
22
23
23
#include " torchinputconns.h"
24
24
25
+ #include " utils/utils.hpp"
26
+
25
27
namespace dd
26
28
{
27
29
@@ -30,7 +32,7 @@ namespace dd
30
32
void TorchInputInterface::build_test_datadb_from_full_datadb (double tsplit)
31
33
{
32
34
_tilogger->info (" splitting : using {} of dataset as test set" , tsplit);
33
- _dataset.reset (db::WRITE);
35
+ _dataset.reset (true , db::WRITE);
34
36
std::vector<int64_t > indicestest;
35
37
int64_t ntest = _dataset._indices .size () * tsplit;
36
38
auto seed = static_cast <long >(time (NULL ));
@@ -85,6 +87,21 @@ namespace dd
85
87
return true ;
86
88
}
87
89
90
+ std::vector<c10::IValue>
91
+ TorchInputInterface::get_input_example (torch::Device device)
92
+ {
93
+ _dataset.reset ();
94
+ auto batchopt = _dataset.get_batch ({ 1 });
95
+ TorchBatch batch = batchopt.value ();
96
+ std::vector<c10::IValue> input_example;
97
+
98
+ for (auto &t : batch.data )
99
+ {
100
+ input_example.push_back (t.to (device));
101
+ }
102
+ return input_example;
103
+ }
104
+
88
105
// ===== TorchDataset
89
106
90
107
void TorchDataset::finalize_db ()
@@ -155,9 +172,8 @@ namespace dd
155
172
_indices.push_back (index );
156
173
}
157
174
158
- void TorchDataset::write_tensors_to_db (
159
- std::vector<at::Tensor> data,
160
- __attribute__ ((unused)) std::vector<at::Tensor> target)
175
+ void TorchDataset::write_tensors_to_db (const std::vector<at::Tensor> &data,
176
+ const std::vector<at::Tensor> &target)
161
177
{
162
178
std::ostringstream dstream;
163
179
torch::save (data, dstream);
@@ -180,7 +196,7 @@ namespace dd
180
196
_txn->Put (data_key.str (), dstream.str ());
181
197
_txn->Put (target_key.str (), tstream.str ());
182
198
183
- // should not commit transations every time;
199
+ // should not commit transactions every time;
184
200
if (++_current_index % _batches_per_transaction == 0 )
185
201
{
186
202
_txn->Commit ();
@@ -189,8 +205,8 @@ namespace dd
189
205
}
190
206
}
191
207
192
- void TorchDataset::add_batch (std::vector<at::Tensor> data,
193
- std::vector<at::Tensor> target)
208
+ void TorchDataset::add_batch (const std::vector<at::Tensor> & data,
209
+ const std::vector<at::Tensor> & target)
194
210
{
195
211
if (!_db)
196
212
_batches.push_back (TorchBatch (data, target));
@@ -369,6 +385,277 @@ namespace dd
369
385
return new_dataset;
370
386
}
371
387
388
+ // ===== ImgTorchInputFileConn
389
+
390
+ void ImgTorchInputFileConn::read_image_folder (
391
+ std::vector<std::pair<std::string, int >> &lfiles,
392
+ std::unordered_map<int , std::string> &hcorresp,
393
+ std::unordered_map<std::string, int > &hcorresp_r,
394
+ const std::string &folderPath)
395
+ {
396
+
397
+ // TODO Put file parsing from caffe in common files to use it in other
398
+ // backends
399
+ int cl = 0 ;
400
+
401
+ std::unordered_set<std::string> subdirs;
402
+ if (fileops::list_directory (folderPath, false , true , false , subdirs))
403
+ throw InputConnectorBadParamException (
404
+ " failed reading image train data directory " + folderPath);
405
+
406
+ auto uit = subdirs.begin ();
407
+ while (uit != subdirs.end ())
408
+ {
409
+ std::unordered_set<std::string> subdir_files;
410
+ if (fileops::list_directory ((*uit), true , false , true , subdir_files))
411
+ throw InputConnectorBadParamException (
412
+ " failed reading image train data sub-directory " + (*uit));
413
+ std::string cls = dd_utils::split ((*uit), ' /' ).back ();
414
+ hcorresp.insert (std::pair<int , std::string>(cl, cls));
415
+ hcorresp_r.insert (std::pair<std::string, int >(cls, cl));
416
+ auto fit = subdir_files.begin ();
417
+ while (
418
+ fit
419
+ != subdir_files.end ()) // XXX: re-iterating the file is not optimal
420
+ {
421
+ lfiles.push_back (std::pair<std::string, int >((*fit), cl));
422
+ ++fit;
423
+ }
424
+ ++cl;
425
+ ++uit;
426
+ }
427
+ }
428
+
429
+ void ImgTorchInputFileConn::transform (const APIData &ad)
430
+ {
431
+ if (!_train)
432
+ {
433
+ try
434
+ {
435
+ ImgInputFileConn::transform (ad);
436
+ }
437
+ catch (InputConnectorBadParamException &e)
438
+ {
439
+ throw ;
440
+ }
441
+
442
+ // XXX: No predict from db yet
443
+ _dataset.set_dbParams (false , " " , " " );
444
+
445
+ for (size_t i = 0 ; i < this ->_images .size (); ++i)
446
+ {
447
+ _dataset.add_batch ({ image_to_tensor (this ->_images [i]) });
448
+ }
449
+ }
450
+ else // if (!_train)
451
+ {
452
+ // This must be done since we don't call ImgInputFileConn::transform
453
+ if (ad.has (" parameters" )) // overriding default parameters
454
+ {
455
+ APIData ad_param = ad.getobj (" parameters" );
456
+ if (ad_param.has (" input" ))
457
+ {
458
+ fillup_parameters (ad_param.getobj (" input" ));
459
+ }
460
+ }
461
+
462
+ // Read all parsed files and create tensor datasets
463
+ bool createDb
464
+ = _db && TorchInputInterface::has_to_create_db (ad, _test_split);
465
+ bool shouldLoad = !_db || createDb;
466
+
467
+ if (shouldLoad)
468
+ {
469
+ if (_db)
470
+ _tilogger->info (" Load from db" );
471
+ // Get files paths
472
+ try
473
+ {
474
+ get_data (ad);
475
+ }
476
+ catch (InputConnectorBadParamException &e)
477
+ {
478
+ throw ;
479
+ }
480
+
481
+ bool dir_images = true ;
482
+ fileops::file_exists (_uris.at (0 ), dir_images);
483
+
484
+ // Parse URIs and retrieve images
485
+ std::unordered_map<int , std::string>
486
+ hcorresp; // correspondence class number / class name
487
+ std::unordered_map<std::string, int >
488
+ hcorresp_r; // reverse correspondence for test set.
489
+ std::vector<std::pair<std::string, int >> lfiles; // labeled files
490
+ std::vector<std::pair<std::string, int >> test_lfiles;
491
+
492
+ bool folder = dir_images;
493
+
494
+ if (folder)
495
+ {
496
+ read_image_folder (lfiles, hcorresp, hcorresp_r, _uris.at (0 ));
497
+ // TODO manage test split
498
+
499
+ if (_uris.size () > 1 )
500
+ {
501
+ std::unordered_map<int , std::string>
502
+ test_hcorresp; // correspondence class number / class
503
+ // name
504
+ std::unordered_map<std::string, int >
505
+ test_hcorresp_r; // reverse correspondence for test
506
+ // set.
507
+
508
+ read_image_folder (test_lfiles, test_hcorresp,
509
+ test_hcorresp_r, _uris.at (1 ));
510
+ }
511
+ }
512
+ else
513
+ {
514
+ throw InputConnectorBadParamException (
515
+ " Torch image input connector expects folders" );
516
+ }
517
+
518
+ bool has_test_data = test_lfiles.size () != 0 ;
519
+
520
+ if (_test_split > 0.0 && !has_test_data)
521
+ {
522
+ // TODO Code for shuffling based on seed / splitting should
523
+ // should be put in a common place
524
+
525
+ // shuffle
526
+ std::mt19937 g;
527
+ if (_seed >= 0 )
528
+ g = std::mt19937 (_seed);
529
+ else
530
+ {
531
+ std::random_device rd;
532
+ g = std::mt19937 (rd ());
533
+ }
534
+ std::shuffle (lfiles.begin (), lfiles.end (), g);
535
+
536
+ // Split
537
+ int split_pos
538
+ = std::floor (lfiles.size () * (1.0 - _test_split));
539
+
540
+ auto split_begin = lfiles.begin ();
541
+ std::advance (split_begin, split_pos);
542
+ test_lfiles.insert (test_lfiles.begin (), split_begin,
543
+ lfiles.end ());
544
+ lfiles.erase (split_begin, lfiles.end ());
545
+
546
+ _logger->info (
547
+ " data split test size={} / remaining data size={}" ,
548
+ test_lfiles.size (), lfiles.size ());
549
+ }
550
+
551
+ // Read data
552
+ for (const std::pair<std::string, int > &lfile : lfiles)
553
+ {
554
+ add_image_file (_dataset, lfile.first , lfile.second );
555
+ }
556
+
557
+ for (const std::pair<std::string, int > &lfile : test_lfiles)
558
+ {
559
+ add_image_file (_test_dataset, lfile.first , lfile.second );
560
+ }
561
+
562
+ // Write corresp file
563
+ std::ofstream correspf (_model_repo + " /" + _correspname,
564
+ std::ios::binary);
565
+ auto hit = hcorresp.begin ();
566
+ while (hit != hcorresp.end ())
567
+ {
568
+ correspf << (*hit).first << " " << (*hit).second << std::endl;
569
+ ++hit;
570
+ }
571
+ correspf.close ();
572
+ }
573
+
574
+ if (createDb)
575
+ {
576
+ _dataset.finalize_db ();
577
+ _test_dataset.finalize_db ();
578
+ }
579
+ }
580
+ }
581
+
582
+ int ImgTorchInputFileConn::add_image_file (TorchDataset &dataset,
583
+ const std::string &fname,
584
+ int target)
585
+ {
586
+ DDImg dimg;
587
+ dimg._bw = _bw;
588
+ dimg._rgb = _rgb;
589
+ dimg._histogram_equalization = _histogram_equalization;
590
+ dimg._unchanged_data = _unchanged_data;
591
+ dimg._crop_width = _crop_width;
592
+ dimg._crop_height = _crop_height;
593
+ dimg._scale = _scale;
594
+ dimg._scaled = _scaled;
595
+ dimg._scale_min = _scale_min;
596
+ dimg._scale_max = _scale_max;
597
+ dimg._keep_orig = _keep_orig;
598
+ dimg._interp = _interp;
599
+
600
+ dimg._width = _width;
601
+ dimg._height = _height;
602
+
603
+ try
604
+ {
605
+ if (dimg.read_file (fname))
606
+ {
607
+ this ->_logger ->error (" Uri failed: {}" , fname);
608
+ }
609
+ }
610
+ catch (std::exception &e)
611
+ {
612
+ this ->_logger ->error (" Uri failed: {}" , fname);
613
+ }
614
+ if (dimg._imgs .size () != 0 )
615
+ {
616
+ at::Tensor imgt = image_to_tensor (dimg._imgs [0 ]);
617
+ at::Tensor targett{ torch::full (1 , target, torch::kLong ) };
618
+
619
+ dataset.add_batch ({ imgt }, { targett });
620
+ return 0 ;
621
+ }
622
+ else
623
+ {
624
+ return -1 ;
625
+ }
626
+ }
627
+
628
+ at::Tensor ImgTorchInputFileConn::image_to_tensor (const cv::Mat &bgr)
629
+ {
630
+ std::vector<int64_t > sizes{ _height, _width, bgr.channels () };
631
+ at::TensorOptions options (at::ScalarType::Byte );
632
+
633
+ at::Tensor imgt = torch::from_blob (bgr.data , at::IntList (sizes), options);
634
+ imgt = imgt.toType (at::kFloat ).permute ({ 2 , 0 , 1 });
635
+ size_t nchannels = imgt.size (0 );
636
+
637
+ if (_scale != 1.0 )
638
+ imgt = imgt.mul (_scale);
639
+
640
+ if (!_mean.empty () && _mean.size () != nchannels)
641
+ throw InputConnectorBadParamException (
642
+ " mean vector be of size the number of channels ("
643
+ + std::to_string (nchannels) + " )" );
644
+
645
+ for (size_t m = 0 ; m < _mean.size (); m++)
646
+ imgt[0 ][m] = imgt[0 ][m].sub_ (_mean.at (m));
647
+
648
+ if (!_std.empty () && _std.size () != nchannels)
649
+ throw InputConnectorBadParamException (
650
+ " std vector be of size the number of channels ("
651
+ + std::to_string (nchannels) + " )" );
652
+
653
+ for (size_t s = 0 ; s < _std.size (); s++)
654
+ imgt[0 ][s] = imgt[0 ][s].div_ (_std.at (s));
655
+
656
+ return imgt;
657
+ }
658
+
372
659
// ===== TxtTorchInputFileConn
373
660
374
661
void TxtTorchInputFileConn::parse_content (const std::string &content,
0 commit comments