Skip to content

Commit

Permalink
polish code & add unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Oct 23, 2020
1 parent dd2562c commit b399823
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 35 deletions.
17 changes: 2 additions & 15 deletions paddle/fluid/framework/op_call_stack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,6 @@ std::string InsertIndentationIntoEachLine(const std::string &str) {
return sout.str();
}

std::string SimplifyErrorTypeFormat(const std::string &str) {
std::ostringstream sout;
size_t type_end_pos = str.find(":", 0);
if (type_end_pos == std::string::npos) {
sout << str;
} else {
// Remove "Error:", add "()""
sout << "(" << str.substr(0, type_end_pos - 5) << ")"
<< str.substr(type_end_pos + 1);
}
return sout.str();
}

void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs,
platform::EnforceNotMet *exception) {
if (attrs.count("sub_block") != 0) {
Expand Down Expand Up @@ -78,9 +65,9 @@ void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs,
// If callstack exists, use err_str_ instead sub_err_str_
if (callstack) {
sout << "\n\n";
sout << InsertIndentationIntoEachLine(exception->what());
sout << InsertIndentationIntoEachLine(exception->error_str());
} else {
sout << SimplifyErrorTypeFormat(exception->what());
sout << exception->simple_error_str();
}
}
sout << " [operator < " << type << " > error]";
Expand Down
62 changes: 50 additions & 12 deletions paddle/fluid/platform/enforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ limitations under the License. */

#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
Expand Down Expand Up @@ -296,15 +295,28 @@ inline std::string GetTraceBackString(StrType&& what, const char* file,
}
}

inline std::string SimplifyErrorTypeFormat(const std::string& str) {
std::ostringstream sout;
size_t type_end_pos = str.find(":", 0);
if (type_end_pos == std::string::npos) {
sout << str;
} else {
// Remove "Error:", add "()""
sout << "(" << str.substr(0, type_end_pos - 5) << ")"
<< str.substr(type_end_pos + 1);
}
return sout.str();
}

inline bool is_error(bool stat) { return !stat; }

// Note: This Macro can only be used within enforce.h
#define __THROW_ERROR_INTERNAL__(ERROR_SUMMARY) \
do { \
HANDLE_THE_ERROR \
throw ::paddle::platform::EnforceNotMet(ERROR_SUMMARY, __FILE__, \
__LINE__); \
END_HANDLE_THE_ERROR \
#define __THROW_ERROR_INTERNAL__(__ERROR_SUMMARY) \
do { \
HANDLE_THE_ERROR \
throw ::paddle::platform::EnforceNotMet(__ERROR_SUMMARY, __FILE__, \
__LINE__); \
END_HANDLE_THE_ERROR \
} while (0)

/** ENFORCE EXCEPTION AND MACROS **/
Expand All @@ -317,29 +329,55 @@ struct EnforceNotMet : public std::exception {
} catch (platform::EnforceNotMet& e) {
code_ = e.code();
err_str_ = GetTraceBackString(e.what(), file, line);
simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
} catch (std::exception& e) {
err_str_ = GetTraceBackString(e.what(), file, line);
simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
}
}

EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {}
: err_str_(GetTraceBackString(str, file, line)) {
simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
}

EnforceNotMet(const ErrorSummary& error, const char* file, int line)
: code_(error.code()),
err_str_(GetTraceBackString(error.to_string(), file, line)) {}
err_str_(GetTraceBackString(error.to_string(), file, line)) {
simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
}

const char* what() const noexcept override { return err_str_.c_str(); }
const char* what() const noexcept override {
if (FLAGS_call_stack_level > 1) {
return err_str_.c_str();
} else {
return simple_err_str_.c_str();
}
}

error::Code code() const { return code_; }

void set_error_str(std::string str) { err_str_ = str; }
const std::string& error_str() const { return err_str_; }

const std::string& simple_error_str() const { return simple_err_str_; }

void set_error_str(std::string str) {
if (FLAGS_call_stack_level > 1) {
err_str_ = str;
} else {
simple_err_str_ = str;
}
}

private:
// Used to determine the final type of exception thrown
error::Code code_ = error::LEGACY;
// Current error message
// Complete error message
// e.g. InvalidArgumentError: ***
std::string err_str_;
// Simple errror message used when no C++ stack and python compile stack
// e.g. (InvalidArgument) ***
std::string simple_err_str_;
};

#define PADDLE_THROW(...) \
Expand Down
8 changes: 2 additions & 6 deletions paddle/fluid/pybind/exception.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <utility>

#include "paddle/fluid/pybind/exception.h"

namespace paddle {
Expand All @@ -37,8 +34,7 @@ namespace pybind {

void BindException(pybind11::module* m) {
static pybind11::exception<platform::EOFException> eof(*m, "EOFException");
static pybind11::exception<platform::EnforceNotMet> ex_base(*m,
"EnforceNotMet");
static pybind11::exception<platform::EnforceNotMet> exc(*m, "EnforceNotMet");
pybind11::register_exception_translator([](std::exception_ptr p) {
try {
if (p) std::rethrow_exception(p);
Expand Down Expand Up @@ -73,7 +69,7 @@ void BindException(pybind11::module* m) {
PyErr_SetString(PyExc_OSError, e.what());
break;
default:
ex_base(e.what());
exc(e.what());
break;
}
}
Expand Down
13 changes: 11 additions & 2 deletions python/paddle/fluid/tests/unittests/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def test_exception(self):
self.assertIsNotNone(exception)


class TestExceptionStatic(unittest.TestCase):
class TestExceptionNoCStack(unittest.TestCase):
def setUp(self):
paddle.enable_static()
# test no C++ stack format
fluid.set_flags({'FLAGS_call_stack_level': 1})

def test_exception_in_static_model(self):
def test_exception_in_static_mode(self):
x = fluid.layers.data(name='X', shape=[-1, 13], dtype='float32')
y = fluid.layers.data(name='Y', shape=[-1, 1], dtype='float32')
predict = fluid.layers.fc(input=x, size=1, act=None)
Expand All @@ -64,6 +64,15 @@ def test_exception_in_static_model(self):
'Y': y},
fetch_list=[avg_loss.name])

def test_exception_in_dynamic_mode(self):
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
x = numpy.random.random(size=(10, 2)).astype('float32')
linear = fluid.dygraph.Linear(1, 10)
data = fluid.dygraph.to_variable(x)
with self.assertRaises(ValueError):
res = linear(data)


if __name__ == "__main__":
unittest.main()

1 comment on commit b399823

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.