Skip to content

Commit 52bd5a5

Browse files
authored
Merge pull request #1 from b-scholz/sam
Fix testcases
2 parents 82c9b3b + 4893cd6 commit 52bd5a5

26 files changed

+248
-166
lines changed

configure.ac

+1
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ AC_CONFIG_FILES([
282282
tests/Makefile
283283
tests/atlocal
284284
tests/interface/functors/Makefile
285+
tests/interface/graph_coloring/Makefile
285286
])
286287
AC_CONFIG_LINKS([utilities/bash-completion/completions/souffle:debian/souffle.bash-completion])
287288

src/include/souffle/io/ReadStream.h

+78-3
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class ReadStream : public SerialisationStream<false> {
9696
consumeWhiteSpace(source, pos);
9797
switch (recordType[0]) {
9898
case 's': {
99-
recordValues[i] = symbolTable.unsafeEncode(readUntil(source, ",]", pos, &consumed));
99+
recordValues[i] = symbolTable.unsafeEncode(readSymbol(source, ",]", pos, &consumed));
100100
break;
101101
}
102102
case 'i': {
@@ -199,7 +199,7 @@ class ReadStream : public SerialisationStream<false> {
199199

200200
switch (argType[0]) {
201201
case 's': {
202-
branchArgs[i] = symbolTable.unsafeEncode(readUntil(source, ",)", pos, &consumed));
202+
branchArgs[i] = symbolTable.unsafeEncode(readSymbol(source, ",)", pos, &consumed));
203203
break;
204204
}
205205
case 'i': {
@@ -267,7 +267,7 @@ class ReadStream : public SerialisationStream<false> {
267267
return source.substr(bgn, pos - bgn);
268268
}
269269

270-
std::string readUntil(const std::string& source, const std::string stopChars, const std::size_t pos,
270+
std::string readUntil(const std::string& source, const std::string& stopChars, const std::size_t pos,
271271
std::size_t* charactersRead) {
272272
std::size_t endOfSymbol = source.find_first_of(stopChars, pos);
273273

@@ -280,6 +280,81 @@ class ReadStream : public SerialisationStream<false> {
280280
return source.substr(pos, *charactersRead);
281281
}
282282

283+
/**
 * Read a quoted symbol that starts at `pos`.
 *
 * The character at `pos` is taken as the quote mark; the symbol extends up to
 * the next unescaped occurrence of that same character. Inside the symbol a
 * backslash makes the following character literal (the backslash itself is
 * dropped, no other translation is performed).
 *
 * @param source          text being parsed
 * @param pos             index of the opening quote mark
 * @param charactersRead  receives the number of characters consumed,
 *                        including both quote marks
 * @return the symbol content with quote marks removed and escapes resolved
 * @throws std::invalid_argument if `pos` is not inside `source` or the
 *         closing quote mark is missing
 */
std::string readQuotedSymbol(const std::string& source, std::size_t pos, std::size_t* charactersRead) {
    const std::size_t length = source.length();

    // Indexing past the end of a std::string is undefined; treat it like any
    // other premature end of input.
    if (pos >= length) {
        throw std::invalid_argument("Unexpected end of input");
    }

    const std::size_t start = pos;
    const char quoteMark = source[start];
    const std::size_t startOfSymbol = start + 1;

    // First pass: locate the closing quote and note whether any escape occurs.
    std::size_t endOfSymbol = std::string::npos;
    bool hasEscaped = false;
    for (std::size_t scan = startOfSymbol; scan < length; ++scan) {
        const char c = source[scan];
        if (c == quoteMark) {
            endOfSymbol = scan;
            break;
        }
        if (c == '\\') {
            // The character after a backslash can never terminate the symbol.
            hasEscaped = true;
            ++scan;
        }
    }

    if (endOfSymbol == std::string::npos) {
        throw std::invalid_argument("Unexpected end of input");
    }

    // Consumed span covers the opening quote through the closing quote.
    *charactersRead = endOfSymbol + 1 - start;

    const std::size_t lengthOfSymbol = endOfSymbol - startOfSymbol;

    // Fast path: nothing to unescape, the content can be copied verbatim.
    if (!hasEscaped) {
        return source.substr(startOfSymbol, lengthOfSymbol);
    }

    // Slow path: drop each escaping backslash and keep the character after it.
    std::string symbol;
    symbol.reserve(lengthOfSymbol);
    for (std::size_t idx = startOfSymbol; idx < endOfSymbol; ++idx) {
        if (source[idx] == '\\') {
            // The first pass guarantees the escaped character lies before endOfSymbol.
            ++idx;
        }
        symbol.push_back(source[idx]);
    }
    return symbol;
}
343+
344+
/**
345+
* Read the next symbol.
346+
* It is either a double-quoted symbol with backslash-escaped chars, or the
347+
* longuest sequence that do not contains any of the given stopChars.
348+
* */
349+
std::string readSymbol(const std::string& source, const std::string& stopChars, const std::size_t pos,
350+
std::size_t* charactersRead) {
351+
if (source[pos] == '"') {
352+
return readQuotedSymbol(source, pos, charactersRead);
353+
} else {
354+
return readUntil(source, stopChars, pos, charactersRead);
355+
}
356+
}
357+
283358
/**
284359
* Read past given character, consuming any preceding whitespace.
285360
*/

src/include/souffle/io/ReadStreamCSV.h

+62-4
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ class ReadStreamCSV : public ReadStream {
5353
int size = static_cast<int>(inputMap.size());
5454
inputMap[size] = size;
5555
}
56+
57+
rfc4180 = (getOr(rwOperation, "rfc4180", "false") == "true");
58+
if (rfc4180 && delimiter.find('"') != std::string::npos) {
59+
std::stringstream errorMessage;
60+
errorMessage << "CSV delimiter cannot contain '\"' character when rfc4180 is enabled.";
61+
throw std::invalid_argument(errorMessage.str());
62+
}
5663
}
5764

5865
protected:
@@ -79,11 +86,10 @@ class ReadStreamCSV : public ReadStream {
7986
++lineNumber;
8087

8188
std::size_t start = 0;
82-
std::size_t end = 0;
8389
std::size_t columnsFilled = 0;
8490
for (uint32_t column = 0; columnsFilled < arity; column++) {
8591
std::size_t charactersRead = 0;
86-
std::string element = nextElement(line, start, end);
92+
std::string element = nextElement(line, start);
8793
if (inputMap.count(column) == 0) {
8894
continue;
8995
}
@@ -156,9 +162,60 @@ class ReadStreamCSV : public ReadStream {
156162
return value;
157163
}
158164

159-
std::string nextElement(const std::string& line, std::size_t& start, std::size_t& end) {
165+
std::string nextElement(const std::string& line, std::size_t& start) {
160166
std::string element;
161167

168+
if (rfc4180) {
169+
if (line[start] == '"') {
170+
// quoted field
171+
const std::size_t end = line.length();
172+
std::size_t pos = start + 1;
173+
bool foundEndQuote = false;
174+
while (pos < end) {
175+
char c = line[pos++];
176+
if (c == '"' && (pos < end) && line[pos] == '"') {
177+
// two double-quote => one double-quote
178+
element.push_back('"');
179+
++pos;
180+
} else if (c == '"') {
181+
foundEndQuote = true;
182+
break;
183+
} else {
184+
element.push_back(c);
185+
}
186+
}
187+
188+
if (!foundEndQuote) {
189+
// missing closing quote
190+
std::stringstream errorMessage;
191+
errorMessage << "Unbalanced field quote in line " << lineNumber << "; ";
192+
throw std::invalid_argument(errorMessage.str());
193+
}
194+
195+
// field must be immediately followed by delimiter or end of line
196+
if (pos != line.length()) {
197+
std::size_t nextDelimiter = line.find(delimiter, pos);
198+
if (nextDelimiter != pos) {
199+
std::stringstream errorMessage;
200+
errorMessage << "Separator expected immediately after quoted field in line "
201+
<< lineNumber << "; ";
202+
throw std::invalid_argument(errorMessage.str());
203+
}
204+
}
205+
206+
start = pos + delimiter.size();
207+
return element;
208+
} else {
209+
// non-quoted field, span until next delimiter or end of line
210+
const std::size_t end = std::min(line.find(delimiter, start), line.length());
211+
element = line.substr(start, end - start);
212+
start = end + delimiter.size();
213+
214+
return element;
215+
}
216+
}
217+
218+
std::size_t end = start;
162219
// Handle record/tuple delimiter coincidence.
163220
if (delimiter.find(',') != std::string::npos) {
164221
int record_parens = 0;
@@ -190,7 +247,7 @@ class ReadStreamCSV : public ReadStream {
190247
// Handle the end-of-the-line case where parenthesis are unbalanced.
191248
if (record_parens != 0) {
192249
std::stringstream errorMessage;
193-
errorMessage << "Unbalanced record parenthesis " << lineNumber << "; ";
250+
errorMessage << "Unbalanced record parenthesis in line " << lineNumber << "; ";
194251
throw std::invalid_argument(errorMessage.str());
195252
}
196253
} else {
@@ -238,6 +295,7 @@ class ReadStreamCSV : public ReadStream {
238295
std::istream& file;
239296
std::size_t lineNumber;
240297
std::map<int, int> inputMap;
298+
bool rfc4180;
241299
};
242300

243301
class ReadFileCSV : public ReadStreamCSV {

src/include/souffle/io/WriteStream.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "souffle/utility/json11.h"
2323
#include <cassert>
2424
#include <cstddef>
25+
#include <iomanip>
2526
#include <map>
2627
#include <memory>
2728
#include <ostream>
@@ -76,6 +77,10 @@ class WriteStream : public SerialisationStream<true> {
7677
writeNextTuple(make_span(tuple).data());
7778
}
7879

80+
/**
 * Write a symbol value to the destination stream.
 *
 * This implementation streams the string unchanged. It is declared virtual
 * so that a derived stream can substitute its own formatting of symbols.
 *
 * @param destination  stream that receives the text
 * @param value        symbol text to write
 */
virtual void outputSymbol(std::ostream& destination, const std::string& value) {
81+
destination << value;
82+
}
83+
7984
void outputRecord(std::ostream& destination, const RamDomain value, const std::string& name) {
8085
auto&& recordInfo = types["records"][name];
8186

@@ -108,7 +113,7 @@ class WriteStream : public SerialisationStream<true> {
108113
case 'i': destination << recordValue; break;
109114
case 'f': destination << ramBitCast<RamFloat>(recordValue); break;
110115
case 'u': destination << ramBitCast<RamUnsigned>(recordValue); break;
111-
case 's': destination << symbolTable.unsafeDecode(recordValue); break;
116+
case 's': outputSymbol(destination, symbolTable.unsafeDecode(recordValue)); break;
112117
case 'r': outputRecord(destination, recordValue, recordType); break;
113118
case '+': outputADT(destination, recordValue, recordType); break;
114119
default: fatal("Unsupported type attribute: `%c`", recordType[0]);
@@ -173,7 +178,7 @@ class WriteStream : public SerialisationStream<true> {
173178
case 'i': destination << branchArgs[i]; break;
174179
case 'f': destination << ramBitCast<RamFloat>(branchArgs[i]); break;
175180
case 'u': destination << ramBitCast<RamUnsigned>(branchArgs[i]); break;
176-
case 's': destination << symbolTable.unsafeDecode(branchArgs[i]); break;
181+
case 's': outputSymbol(destination, symbolTable.unsafeDecode(branchArgs[i])); break;
177182
case 'r': outputRecord(destination, branchArgs[i], argType); break;
178183
case '+': outputADT(destination, branchArgs[i], argType); break;
179184
default: fatal("Unsupported type attribute: `%c`", argType[0]);

src/include/souffle/io/WriteStreamCSV.h

+54-4
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,14 @@ class WriteStreamCSV : public WriteStream {
4242
WriteStreamCSV(const std::map<std::string, std::string>& rwOperation, const SymbolTable& symbolTable,
4343
const RecordTable& recordTable)
4444
: WriteStream(rwOperation, symbolTable, recordTable),
45-
delimiter(getOr(rwOperation, "delimiter", "\t")){};
45+
delimiter(getOr(rwOperation, "delimiter", "\t")) {
46+
rfc4180 = (getOr(rwOperation, "rfc4180", "false") == std::string("true"));
47+
};
4648

4749
const std::string delimiter;
4850

51+
bool rfc4180 = false;
52+
4953
void writeNextTupleCSV(std::ostream& destination, const RamDomain* tuple) {
5054
writeNextTupleElement(destination, typeAttributes.at(0), tuple[0]);
5155

@@ -57,14 +61,60 @@ class WriteStreamCSV : public WriteStream {
5761
destination << "\n";
5862
}
5963

64+
virtual void outputSymbol(std::ostream& destination, const std::string& value) {
65+
outputSymbol(destination, value, false);
66+
}
67+
68+
void outputSymbol(std::ostream& destination, const std::string& value, bool fieldValue) {
69+
if (rfc4180) {
70+
if (!fieldValue) {
71+
destination << '"';
72+
}
73+
destination << '"';
74+
75+
const std::size_t end = value.length();
76+
for (std::size_t pos = 0; pos < end; ++pos) {
77+
char ch = value[pos];
78+
if (ch == '"') {
79+
destination << '\\';
80+
destination << '"';
81+
}
82+
destination << ch;
83+
}
84+
85+
if (!fieldValue) {
86+
destination << '"';
87+
}
88+
destination << '"';
89+
} else {
90+
destination << value;
91+
}
92+
}
93+
6094
void writeNextTupleElement(std::ostream& destination, const std::string& type, RamDomain value) {
6195
switch (type[0]) {
62-
case 's': destination << symbolTable.unsafeDecode(value); break;
96+
case 's': outputSymbol(destination, symbolTable.unsafeDecode(value), true); break;
6397
case 'i': destination << value; break;
6498
case 'u': destination << ramBitCast<RamUnsigned>(value); break;
6599
case 'f': destination << ramBitCast<RamFloat>(value); break;
66-
case 'r': outputRecord(destination, value, type); break;
67-
case '+': outputADT(destination, value, type); break;
100+
case 'r':
101+
if (rfc4180) {
102+
destination << '"';
103+
}
104+
outputRecord(destination, value, type);
105+
if (rfc4180) {
106+
destination << '"';
107+
}
108+
break;
109+
case '+':
110+
if (rfc4180) {
111+
destination << '"';
112+
}
113+
outputADT(destination, value, type);
114+
if (rfc4180) {
115+
destination << '"';
116+
}
117+
break;
68118
default: fatal("unsupported type attribute: `%c`", type[0]);
69119
}
70120
}

tests/Makefile.am

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# - https://opensource.org/licenses/UPL
55
# - <souffle root>/licenses/SOUFFLE-UPL.txt
66

7-
SUBDIRS = interface/functors example/graph_coloring
7+
SUBDIRS = interface/functors interface/graph_coloring
88

99
EXTRA_DIST = $(srcdir)/*.at package.m4 $(TESTSUITE) atlocal.in $(srcdir)/swig $(srcdir)/evaluation $(srcdir)/semantic $(srcdir)/syntactic $(srcdir)/interface $(srcdir)/profile $(srcdir)/provenance
1010

0 commit comments

Comments
 (0)