Skip to content

Commit 1a9af3e

Browse files
benizsileht
authored and committed
fix: add support and automated processing of categorical variables in timeseries data
1 parent 69ff0fb commit 1a9af3e

6 files changed

+269
-125
lines changed

src/backends/caffe/caffeinputconns.cc

+76-39
Original file line number | Diff line number | Diff line change
@@ -1669,6 +1669,7 @@ namespace dd
16691669
{
16701670
if (_cifc)
16711671
{
1672+
_cifc->_columns.clear();
16721673
std::string test_file = _cifc->_csv_test_fname;
16731674
_cifc->_csv_test_fname = "";
16741675
_cifc->read_csv(fname);
@@ -1697,63 +1698,100 @@ namespace dd
16971698
return 0;
16981699
}
16991700

1700-
int DDCCsvTS::read_dir(const std::string &dir, bool is_test_data,
1701-
bool update_bounds)
1701+
int DDCCsvTS::read_dir(const std::string &dir)
17021702
{
1703-
// first recursive list csv files
1704-
std::unordered_set<std::string> allfiles;
1705-
int ret = fileops::list_directory(dir, true, false, true, allfiles);
1703+
//- list all CSV files in directory
1704+
std::unordered_set<std::string> trainfiles;
1705+
int ret = fileops::list_directory(dir, true, false, true, trainfiles);
17061706
if (ret != 0)
17071707
return ret;
17081708
// then simply read them
17091709
if (!_cifc)
17101710
return -1;
17111711

1712-
if (update_bounds && _cifc->_scale
1713-
&& (_cifc->_min_vals.empty() || _cifc->_max_vals.empty()))
1712+
//- pick one file up and read header once
1713+
std::string fname = (*trainfiles.begin());
1714+
std::ifstream csv_file(fname, std::ios::binary);
1715+
if (!csv_file.is_open())
1716+
throw InputConnectorBadParamException("cannot open file " + fname);
1717+
std::string hline;
1718+
std::getline(csv_file, hline);
1719+
_cifc->read_header(hline);
1720+
1721+
//- read all test files
1722+
std::unordered_set<std::string> testfiles;
1723+
if (!_cifc->_csv_test_fname.empty())
1724+
fileops::list_directory(_cifc->_csv_test_fname, true, false, true,
1725+
testfiles);
1726+
1727+
std::unordered_set<std::string> allfiles = trainfiles;
1728+
1729+
//- aggregate all files = train + test
1730+
allfiles.insert(testfiles.begin(), testfiles.end());
1731+
1732+
//- read categoricals first if any as it affects the number of columns (and
1733+
// thus bounds)
1734+
if (!_cifc->_categoricals.empty())
17141735
{
1715-
std::unordered_set<std::string> reallyallfiles;
1716-
ret = fileops::list_directory(_cifc->_csv_test_fname, true, false,
1717-
true, reallyallfiles);
1718-
reallyallfiles.insert(allfiles.begin(), allfiles.end());
1736+
std::unordered_map<std::string, CCategorical> categoricals;
1737+
for (auto fname : allfiles)
1738+
{
1739+
csv_file = std::ifstream(fname, std::ios::binary);
1740+
if (!csv_file.is_open())
1741+
throw InputConnectorBadParamException("cannot open file "
1742+
+ fname);
1743+
std::string hline;
1744+
std::getline(csv_file, hline); // skip header
1745+
1746+
// read on categoricals
1747+
_cifc->fillup_categoricals(csv_file);
1748+
_cifc->merge_categoricals(categoricals);
1749+
}
1750+
}
17191751

1720-
std::vector<double> min_vals = _cifc->_min_vals;
1721-
std::vector<double> max_vals = _cifc->_max_vals;
1722-
for (auto fname : reallyallfiles)
1752+
//- read bounds across all TS CSV files
1753+
if (_cifc->_scale
1754+
&& (_cifc->_min_vals.empty() || _cifc->_max_vals.empty()))
1755+
{
1756+
std::vector<double> min_vals(_cifc->_min_vals);
1757+
std::vector<double> max_vals(_cifc->_max_vals);
1758+
for (auto fname : allfiles)
17231759
{
1724-
std::pair<std::vector<double>, std::vector<double>> mm
1725-
= _cifc->get_min_max_vals(fname);
1726-
if (min_vals.empty())
1727-
min_vals = mm.first;
1728-
else
1729-
for (size_t j = 0; j < mm.first.size(); j++)
1730-
min_vals.at(j) = std::min(mm.first.at(j), min_vals.at(j));
1731-
if (max_vals.empty())
1732-
max_vals = mm.second;
1733-
else
1734-
for (size_t j = 0; j < mm.first.size(); j++)
1735-
max_vals.at(j) = std::max(mm.second.at(j), max_vals.at(j));
1760+
csv_file = std::ifstream(fname, std::ios::binary);
1761+
if (!csv_file.is_open())
1762+
throw InputConnectorBadParamException("cannot open file "
1763+
+ fname);
1764+
std::string hline;
1765+
std::getline(csv_file, hline); // skip header
1766+
1767+
//- read bounds min/max
1768+
_cifc->_min_vals.clear();
1769+
_cifc->_max_vals.clear();
1770+
_cifc->find_min_max(csv_file);
1771+
_cifc->merge_min_max(min_vals, max_vals);
17361772
}
1773+
1774+
//- update global bounds
17371775
_cifc->_min_vals = min_vals;
17381776
_cifc->_max_vals = max_vals;
17391777
_cifc->serialize_bounds();
17401778
}
17411779

1742-
if (!is_test_data && _cifc->_shuffle)
1780+
// shuffle training data as needed
1781+
std::vector<std::string> trainfiles_v;
1782+
for (auto fname : trainfiles)
1783+
trainfiles_v.push_back(fname);
1784+
if (_cifc->_shuffle)
17431785
{
1744-
std::vector<std::string> allfiles_v;
1745-
for (auto fname : allfiles)
1746-
allfiles_v.push_back(fname);
17471786
auto rng = std::default_random_engine();
1748-
std::shuffle(allfiles_v.begin(), allfiles_v.end(), rng);
1749-
for (auto fname : allfiles_v)
1750-
read_file(fname, is_test_data);
1787+
std::shuffle(trainfiles_v.begin(), trainfiles_v.end(), rng);
17511788
}
1752-
else
1753-
1754-
for (auto fname : allfiles)
1755-
read_file(fname, is_test_data);
17561789

1790+
for (auto fname : trainfiles_v)
1791+
read_file(fname, false);
1792+
for (auto fname : testfiles)
1793+
read_file(fname, true);
1794+
_cifc->update_columns();
17571795
return 0;
17581796
}
17591797

@@ -2034,8 +2072,7 @@ namespace dd
20342072
DDCCsvTS ddccsvts;
20352073
ddccsvts._cifc = this;
20362074
ddccsvts._adconf = ad_input;
2037-
ddccsvts.read_dir(_csv_fname, false, true);
2038-
ddccsvts.read_dir(_csv_test_fname, true, false);
2075+
ddccsvts.read_dir(_csv_fname);
20392076

20402077
_txn->Commit();
20412078
_ttxn->Commit();

src/backends/caffe/caffeinputconns.h

+1-2
Original file line number | Diff line number | Diff line change
@@ -894,8 +894,7 @@ namespace dd
894894
int read_file(const std::string &fname, bool is_test_data = false);
895895
int read_db(const std::string &fname);
896896
int read_mem(const std::string &content);
897-
int read_dir(const std::string &dir, bool is_test_data = false,
898-
bool update_bounds = true);
897+
int read_dir(const std::string &dir);
899898

900899
DDCsvTS _ddcsvts;
901900
CSVTSCaffeInputFileConn *_cifc = nullptr;

src/csvinputfileconn.cc

+46-39
Original file line number | Diff line number | Diff line change
@@ -280,6 +280,49 @@ namespace dd
280280
throw InputConnectorBadParamException("cannot find id column " + _id);
281281
}
282282

283+
void CSVInputFileConn::fillup_categoricals(std::ifstream &csv_file)
284+
{
285+
int l = 0;
286+
std::string hline;
287+
while (std::getline(csv_file, hline))
288+
{
289+
hline.erase(std::remove(hline.begin(), hline.end(), '\r'),
290+
hline.end());
291+
std::vector<double> vals;
292+
std::string cid;
293+
std::string col;
294+
auto hit = _columns.begin();
295+
std::unordered_set<int>::const_iterator igit;
296+
std::stringstream sh(hline);
297+
int cu = 0;
298+
while (std::getline(sh, col, _delim[0]))
299+
{
300+
if (cu >= _detect_cols)
301+
{
302+
_logger->error("line {} has more columns than headers / this "
303+
"line: {} / header: {}",
304+
l, cu, _detect_cols);
305+
_logger->error(hline);
306+
throw InputConnectorBadParamException(
307+
"line has more columns than headers");
308+
}
309+
if ((igit = _ignored_columns_pos.find(cu))
310+
!= _ignored_columns_pos.end())
311+
{
312+
++cu;
313+
continue;
314+
}
315+
update_category((*hit), col);
316+
++hit;
317+
++cu;
318+
}
319+
++l;
320+
}
321+
csv_file.clear();
322+
csv_file.seekg(0, std::ios::beg);
323+
std::getline(csv_file, hline); // skip header line
324+
}
325+
283326
void CSVInputFileConn::find_min_max(std::ifstream &csv_file)
284327
{
285328
int nlines = 0;
@@ -330,42 +373,7 @@ namespace dd
330373
// categorical variables
331374
if (_train && !_categoricals.empty())
332375
{
333-
int l = 0;
334-
while (std::getline(csv_file, hline))
335-
{
336-
hline.erase(std::remove(hline.begin(), hline.end(), '\r'),
337-
hline.end());
338-
std::vector<double> vals;
339-
std::string cid;
340-
std::string col;
341-
auto hit = _columns.begin();
342-
std::unordered_set<int>::const_iterator igit;
343-
std::stringstream sh(hline);
344-
int cu = 0;
345-
while (std::getline(sh, col, _delim[0]))
346-
{
347-
if (cu >= _detect_cols)
348-
{
349-
_logger->error("line {} has more columns than headers", l);
350-
_logger->error(hline);
351-
throw InputConnectorBadParamException(
352-
"line has more columns than headers");
353-
}
354-
if ((igit = _ignored_columns_pos.find(cu))
355-
!= _ignored_columns_pos.end())
356-
{
357-
++cu;
358-
continue;
359-
}
360-
update_category((*hit), col);
361-
++hit;
362-
++cu;
363-
}
364-
++l;
365-
}
366-
csv_file.clear();
367-
csv_file.seekg(0, std::ios::beg);
368-
std::getline(csv_file, hline); // skip header line
376+
fillup_categoricals(csv_file);
369377
}
370378

371379
// scaling to [0,1]
@@ -397,8 +405,8 @@ namespace dd
397405
// debug
398406
/*std::cout << "csv data line #" << nlines << "= " << vals.size() <<
399407
std::endl;
400-
std::copy(vals.begin(),vals.end(),std::ostream_iterator<double>(std::cout,"
401-
")); std::cout << std::endl;*/
408+
std::copy(vals.begin(),vals.end(),std::ostream_iterator<double>(std::cout,""));
409+
std::cout << std::endl;*/
402410
// debug
403411
}
404412
_logger->info("read {} lines from {}", nlines, fname);
@@ -454,5 +462,4 @@ namespace dd
454462
if (!_ignored_columns.empty() || !_categoricals.empty())
455463
update_columns();
456464
}
457-
458465
}

src/csvinputfileconn.h

+2
Original file line number | Diff line number | Diff line change
@@ -482,6 +482,8 @@ namespace dd
482482

483483
void read_header(std::string &hline);
484484

485+
void fillup_categoricals(std::ifstream &csv_file);
486+
485487
void read_csv_line(const std::string &hline, const std::string &delim,
486488
std::vector<double> &vals, std::string &column_id,
487489
int &nlines);

0 commit comments

Comments
 (0)