Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix parse string #57314

Merged
merged 10 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion paddle/pir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,15 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
}

if (auto s = attr.dyn_cast<StrAttribute>()) {
os << "(String)" << s.AsString();
std::string s_val = s.AsString();
std::string replacement = "\\\"";
std::string search = "\"";
size_t found = s_val.find(search);
while (found != std::string::npos) {
s_val.replace(found, search.length(), replacement);
found = s_val.find(search, found + replacement.length());
}
os << "\"" << s_val << "\"";
} else if (auto b = attr.dyn_cast<BoolAttribute>()) {
if (b.data()) {
os << "true";
Expand Down
28 changes: 9 additions & 19 deletions paddle/pir/core/parser/ir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,14 @@ IrParser::IrParser(IrContext* ctx, std::istream& is) {
builder.reset(new Builder{ctx});
}

Token IrParser::ConsumeToken() {
auto token = lexer->ConsumeToken();
return token;
}
Token IrParser::ConsumeToken() { return lexer->ConsumeToken(); }

// Builds the location suffix used in parser error messages from the line and
// column currently reported by the lexer (GetLine() / GetColumn()).
std::string IrParser::GetErrorLocationInfo() {
  std::string message = "The error occurred in line ";
  message += std::to_string(lexer->GetLine());
  message += ", column ";
  message += std::to_string(lexer->GetColumn());
  return message;
}

Token IrParser::PeekToken() {
auto token = lexer->ConsumeToken();
if (token.token_type_ != EOF_) {
lexer->Unget(token.val_.size());
}
return token;
}
Token IrParser::PeekToken() { return lexer->PeekToken(); }

void IrParser::ConsumeAToken(std::string expect_token_val) {
std::string token_val = ConsumeToken().val_;
Expand Down Expand Up @@ -128,14 +119,13 @@ Attribute IrParser::ParseAttribute() {
auto parenthesis_token = ConsumeToken();
if (parenthesis_token.val_ == "true" || parenthesis_token.val_ == "false") {
return builder->bool_attr(parenthesis_token.val_ == "true");
} else if (parenthesis_token.token_type_ == STRING) {
std::string val = parenthesis_token.val_;
val = val.substr(1, val.size() - 2);
return builder->str_attr(val);
}
std::string attribute_type = PeekToken().val_;
if (attribute_type == "String") {
ConsumeAToken("String");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
return builder->str_attr(val);
} else if (attribute_type == "Float") {
if (attribute_type == "Float") {
ConsumeAToken("Float");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
Expand Down Expand Up @@ -216,7 +206,7 @@ Operation* IrParser::ParseOperation() {

OpInfo opinfo = ParseOpInfo();

std::vector<Value> inputs = ParseOprandList();
std::vector<Value> inputs = ParseOperandList();

pir::AttributeMap attributeMap = ParseAttributeMap();

Expand Down Expand Up @@ -269,7 +259,7 @@ OpInfo IrParser::ParseOpInfo() {

// OperandList := ValueList
// ValueList := ValueId(,ValueId)*
std::vector<Value> IrParser::ParseOprandList() {
std::vector<Value> IrParser::ParseOperandList() {
ConsumeAToken("(");
std::vector<Value> inputs{};
Token ind_token = ConsumeToken();
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/parser/ir_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class IrParser {

std::vector<std::string> ParseOpResultList();

std::vector<Value> ParseOprandList();
std::vector<Value> ParseOperandList();

AttributeMap ParseAttributeMap();

Expand Down
28 changes: 21 additions & 7 deletions paddle/pir/core/parser/lexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Token Lexer::ConsumeToken() {
return *token;
} else if (auto token = LexValueId()) {
return *token;
} else if (auto token = LexOpName()) {
} else if (auto token = LexString()) {
return *token;
} else if (auto token = LexEOF()) {
return *token;
Expand All @@ -33,6 +33,16 @@ Token Lexer::ConsumeToken() {
}
}

// Returns the next token without consuming it: the current read position is
// remembered, one token is lexed with ConsumeToken(), and the stream is then
// moved back to the remembered position.
// NOTE(review): only the stream offset is restored here. If lexing updates
// line/column bookkeeping (GetChar appears to react to '\n'), that state is
// presumably not rewound by this function -- confirm.
Token Lexer::PeekToken() {
  const auto saved_pos = is.tellg();
  Token lookahead = ConsumeToken();
  // Lexing the final token can run into end-of-input; reset the stream state
  // in that case before seeking back.
  if (is.eof()) {
    is.clear();
  }
  is.seekg(saved_pos);
  return lookahead;
}

char Lexer::GetChar() {
char c = is.get();
if (c == '\n') {
Expand Down Expand Up @@ -160,19 +170,23 @@ std::unique_ptr<Token> Lexer::LexEOF() {
}
}

std::unique_ptr<Token> Lexer::LexOpName() {
std::unique_ptr<Token> Lexer::LexString() {
if (is.peek() != '"') {
return nullptr;
}
GetChar();
std::string token_opname = "";
std::string token_val = "";
while (is.peek() != '"') {
token_opname += GetChar();
char c = GetChar();
if (c == '\\' && is.peek() == '\"') {
c = GetChar();
}
token_val += c;
}
GetChar();
std::unique_ptr<Token> opname_token(
new Token{"\"" + token_opname + "\"", OPNAME});
return opname_token;
std::unique_ptr<Token> string_token(
new Token{"\"" + token_val + "\"", STRING});
return string_token;
}

bool Lexer::IsSpace(char c) {
Expand Down
3 changes: 2 additions & 1 deletion paddle/pir/core/parser/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ class Lexer {
explicit Lexer(std::istream& is) : is(is) {}
~Lexer() = default;
Token ConsumeToken();
Token PeekToken();
std::unique_ptr<Token> LexIdentifer();
std::unique_ptr<Token> LexNumberOrArraow();
std::unique_ptr<Token> LexEndTagOrNullVal();
std::unique_ptr<Token> LexValueId();
std::unique_ptr<Token> LexEOF();
std::unique_ptr<Token> LexOpName();
std::unique_ptr<Token> LexString();
char GetChar();
void SkipWhitespace();
bool IsEndTag(char);
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/parser/token.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ enum Token_type {
SDIGIT = 2,
ENDTAG = 3,
VALUEID = 4,
OPNAME = 5,
STRING = 5,
ARRAOW = 6,
NULL_ = 7,
};
Expand Down
43 changes: 35 additions & 8 deletions test/cpp/pir/core/TestParserText.txt
Original file line number Diff line number Diff line change
@@ -1,43 +1,70 @@

//CHECK attribute
(String)sdfgs.sdsd
(Array)[" File \"train.py\", line 225, in <module>",
" main(args)",
" File \"train.py\", line 197, in main",
" lr_scheduler, args.profiler_options)",
" File \"/home/PaddleClas/ppcls/static/program.py\", line 397, in run",
" fetch_list=fetch_list)",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 1440, in run",
" use_prune=use_prune,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 1635, in _run_impl",
" scope,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 801, in get_program_and_executor",
" scope,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 866, in _get_program_and_executor",
" use_fetch_v2=True,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 411, in _add_feed_fetch_ops",
" attrs={'col': i},",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/framework.py\", line 4056, in append_op",
" attrs=kwargs.get(\"attrs\", None),",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/framework.py\", line 2818, in __init__",
" for frame in traceback.extract_stack():"]
//END

//CHECK type
f32
//END

//CHECK type
pd_op.tensor<256xf32>
//END

//CHECK program
{
(%0) = "builtin.get_parameter" () {parameter_name:(String)conv2d_0.w_0} : () -> pd_op.tensor<64x3x7x7xf32>
(%1) = "pd_op.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:(String)data,stop_gradient:(Array)[true]} : () -> pd_op.tensor<-1x3x224x224xf32>
(%2) = "pd_op.conv2d" (%1, %0) {data_format:(String)NCHW,dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:(String)EXPLICIT,paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd_op.tensor<-1x3x224x224xf32>, pd_op.tensor<64x3x7x7xf32>) -> pd_op.tensor<-1x64x112x112xf32>
(%0) = "builtin.get_parameter" () {parameter_name:"conv2d_0.w_0"} : () -> pd_op.tensor<64x3x7x7xf32>
(%1) = "pd_op.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:"data",stop_gradient:(Array)[true]} : () -> pd_op.tensor<-1x3x224x224xf32>
(%2) = "pd_op.conv2d" (%1, %0) {data_format:"NCHW",dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:"EXPLICIT",paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd_op.tensor<-1x3x224x224xf32>, pd_op.tensor<64x3x7x7xf32>) -> pd_op.tensor<-1x64x112x112xf32>
}
//END

//CHECK attribute
(Array)[(pd_op.DataType)bool,(pd_op.DataType)float32,(pd_op.DataType)float64,
(pd_op.DataType)complex64,(pd_op.DataType)complex128,(pd_op.DataType)Undefined,
(pd_op.DataType)Undefined,(pd_op.DataType)Undefined,(pd_op.DataType)Undefined,
(pd_op.DataType)bfloat16,(pd_op.DataType)uint8,(pd_op.DataType)uint32,(pd_op.DataType)int8,
(pd_op.DataType)uint16,(pd_op.DataType)int16,(pd_op.DataType)int32,(pd_op.DataType)uint64,(pd_op.DataType)int64]

//END

//CHECK attribute
(Array)[(pd_op.Place)Place(gpu:0),(pd_op.Place)Place(gpu_pinned),(pd_op.Place)Place(gpu_pinned),
(pd_op.Place)Place(xpu:0),(pd_op.Place)Place(ipu:0),(pd_op.Place)Place(:0),(pd_op.Place)Place(cpu)]

//END

//CHECK attribute
(Array)[(pd_op.DataLayout)NHWC,(pd_op.DataLayout)STRIDED,(pd_op.DataLayout)NCHW,(pd_op.DataLayout)Undefined(AnyLayout),
(pd_op.DataLayout)ONEDNN,(pd_op.DataLayout)SPARSE_COO,(pd_op.DataLayout)SPARSE_CSR,(pd_op.DataLayout)NDHWC,(pd_op.DataLayout)NCDHW,
(pd_op.DataLayout)PSTRING_UNION]
//END

//CHECK attribute
(Array)[(Double)1,(Int64)0,(String)1]
(Array)[(Double)1,(Int64)0,"1"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(Array)[(Double)1,(Int64)0,"1"]
[(Double)1,(Int64)0,"1"]

接下来还需要优化下针对Array的parser,跟String类似,我们通过[]应该就可以识别,感觉没必要加一个Array,比较累赘

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外针对String Attribute补充一些测试用例吧,比如

Suggested change
(Array)[(Double)1,(Int64)0,"1"]
"\""
"\\""
"\\\""
"\t\r\n\""

之类的

//END

//CHECK type
vec[bf16,f64,b,i8,u8,i16,c64,c128]
//END

//CHECK attribute
(String)1
(Array)["\"","\\"","\\\"","\t\n\r",""]
//END
2 changes: 1 addition & 1 deletion test/cpp/pir/core/add_dialect_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ TEST(IrParserTest, AddAttribute) {

std::string op_str =
" (%0) = \"builtin.get_parameter\" () "
"{parameter_name:(String)conv2d_0.w_0,test:(tp.char)a} : () -> "
"{parameter_name:\"conv2d_0.w_0\",test:(tp.char)a} : () -> "
"pd_op.tensor<64x3x7x7xf32>";
std::stringstream ss;
ss << op_str;
Expand Down
57 changes: 51 additions & 6 deletions test/cpp/pir/core/ir_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,35 +60,58 @@ class ParserTest {
explicit ParserTest(std::ifstream& test_text) : test_text(test_text) {}
TestTask* GetTestTask();
bool ConsumeTestTask(TestTask* test_task, pir::IrContext* ctx);
std::string Peek(const size_t len);
std::string Get(const size_t len);
};

TestTask* ParserTest::GetTestTask() {
while (test_text.peek() == '\n' || test_text.peek() == ' ') {
test_text.get();
}

if (test_text.peek() == EOF) {
return nullptr;
}
std::string test_info;
while (test_text.peek() != '/') {

while (Peek(7) != "//CHECK" && test_text.peek() != EOF) {
test_text.get();
}
while (test_text.peek() != ' ') {

while (test_text.peek() != ' ' && test_text.peek() != EOF) {
test_text.get();
}

test_text.get();

std::string test_type_info;
while (test_text.peek() != '\n') {
while (test_text.peek() != '\n' && test_text.peek() != ' ' &&
test_text.peek() != EOF) {
test_type_info += test_text.get();
}
test_text.get();
while (test_text.peek() != '/' && test_text.peek() != EOF) {

while (test_text.peek() == '\n' || test_text.peek() == ' ') {
test_text.get();
}

std::string test_info;
while (Peek(5) != "//END" && test_text.peek() != EOF) {
test_info += test_text.get();
}

if (Peek(5) != "//END" || test_info.size() == 0) {
return nullptr;
}

Get(5);

if (test_type_info == "attribute") {
return new TestTask(AttributeTest, test_info);
} else if (test_type_info == "type") {
return new TestTask(TypeTest, test_info);
} else if (test_type_info == "program") {
return new TestTask(ProgramTest, test_info);
}

return nullptr;
}

Expand Down Expand Up @@ -135,6 +158,28 @@ bool ParserTest::ConsumeTestTask(TestTask* test_task, pir::IrContext* ctx) {
return true;
}

// Returns up to `len` upcoming characters of the test file without consuming
// them: the characters are read with Get(len) and the stream is then moved
// back to where it was before the read.
std::string ParserTest::Peek(const size_t len) {
  const auto start = test_text.tellg();
  std::string lookahead;
  lookahead = Get(len);
  // Reading may have reached end-of-file; reset the stream state in that case
  // before seeking back.
  if (test_text.eof()) {
    test_text.clear();
  }
  test_text.seekg(start);
  return lookahead;
}

// Consumes and returns up to `len` characters from the test file. Reading
// stops early once peek() reports EOF, so the result may be shorter than
// `len` (and is empty when `len` is 0 or no input remains).
std::string ParserTest::Get(const size_t len) {
  std::string chunk;
  size_t taken = 0;
  while (taken < len && test_text.peek() != EOF) {
    chunk += test_text.get();
    ++taken;
  }
  return chunk;
}

TEST(IrParserTest, TestParserByFile) {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<OperatorDialect>();
Expand Down