diff --git a/paddle/fluid/framework/ipu/CMakeLists.txt b/paddle/fluid/framework/ipu/CMakeLists.txt
index 0521fbde0b4dd9..318d3d68a66a7b 100644
--- a/paddle/fluid/framework/ipu/CMakeLists.txt
+++ b/paddle/fluid/framework/ipu/CMakeLists.txt
@@ -7,9 +7,9 @@ set(POPART_CANONICALIZATION_SRC
   "popart_canonicalization/tensor_ops.cc"
   "popart_canonicalization/other_ops.cc"
 )
-cc_library(popart_canonicalization_utils SRCS ${POPART_CANONICALIZATION_SRC} DEPS framework_proto enforce)
 cc_library(ipu_device SRCS device.cc DEPS enforce popart)
 cc_library(ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart)
 cc_library(ipu_build_strategy SRCS ipu_build_strategy.cc DEPS popart graph framework_proto enforce)
 cc_library(ipu_backend SRCS ipu_backend.cc DEPS popart graph framework_proto enforce ipu_utils ipu_build_strategy ipu_device graph_helper)
+cc_library(popart_canonicalization_utils SRCS ${POPART_CANONICALIZATION_SRC} DEPS framework_proto enforce ipu_utils)
diff --git a/paddle/fluid/framework/ipu/ipu_backend.cc b/paddle/fluid/framework/ipu/ipu_backend.cc
index 0f94888cddc4f5..9d7f6e7d4d7ece 100644
--- a/paddle/fluid/framework/ipu/ipu_backend.cc
+++ b/paddle/fluid/framework/ipu/ipu_backend.cc
@@ -249,10 +249,12 @@ void IpuBackend::LowerWeights(const ir::Graph* graph) {
 }
 
 void IpuBackend::LowerBody(const ir::Graph* graph) {
+  VLOG(10) << "enter IpuBackend::LowerBody";
   auto nodes = TopologySortOperations(*graph);
   for (const auto* node : nodes) {
     auto* op = node->Op();
     auto op_type = op->Type();
+    VLOG(10) << "Lowering Node: " << node->Name() << " Op: " << op_type;
     if (op_type == "RandomUniform") {
       auto outputs = op->Output("__outputs__");
       auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
@@ -271,16 +273,42 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
       popart::TensorId result =
           builder_->aiOnnxOpset11().randomnormal(shape, dtype, mean, scale);
       tensors_.emplace(outputs[0], result);
-    } else if (op_type == "ConstantOfShape") {
-      // TODO(alleng) use RandomUniform for now
+    } else if (op_type == "Constant") {
       auto outputs = op->Output("__outputs__");
-      auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
-      auto dtype = BOOST_GET_CONST(int, op->GetAttr("dtype"));
-      auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
-      auto high = value;
-      auto low = value;
-      popart::TensorId result =
-          builder_->aiOnnxOpset11().randomuniform(shape, dtype, high, low);
+      auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("dims"));
+      auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
+      auto dtype = OnnxDtype2PopartType(dtype_);
+      popart::TensorInfo tensor_info{dtype, dims};
+      auto value_attr = op->GetAttr("value");
+      auto const_data = std::unique_ptr<popart::ConstVoidData>{};
+      // Attribute only supports vectors of int, int64_t, float, double and
+      // bool; if we need to store another data type, we should find another
+      // way.
+      switch (dtype) {
+        case popart::DataType::FLOAT:
+          const_data.reset(new popart::ConstVoidData(
+              BOOST_GET_CONST(std::vector<float>, value_attr).data(),
+              tensor_info));
+          break;
+        case popart::DataType::INT32:
+          const_data.reset(new popart::ConstVoidData(
+              BOOST_GET_CONST(std::vector<int>, value_attr).data(),
+              tensor_info));
+          break;
+        case popart::DataType::DOUBLE:
+          const_data.reset(new popart::ConstVoidData(
+              BOOST_GET_CONST(std::vector<double>, value_attr).data(),
+              tensor_info));
+          break;
+        case popart::DataType::INT64:
+          const_data.reset(new popart::ConstVoidData(
+              BOOST_GET_CONST(std::vector<int64_t>, value_attr).data(),
+              tensor_info));
+          break;
+        default:
+          PADDLE_THROW(
+              platform::errors::Unimplemented("popart::DataType %d", dtype));
+      }
+      popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
       tensors_.emplace(outputs[0], result);
     } else if (op_type == "Add") {
       auto inputs = GetOpInputs(op);
@@ -310,6 +338,11 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
       popart::TensorId result =
           builder_->aiOnnxOpset11().reducemean(inputs, axes, keepdims);
       tensors_.emplace(outputs[0], result);
+    } else if (op_type == "Pow") {
+      auto inputs = GetOpInputs(op);
+      auto outputs = op->Output("__outputs__");
+      popart::TensorId result = builder_->aiOnnxOpset11().pow(inputs);
+      tensors_.emplace(outputs[0], result);
     } else {
       PADDLE_THROW(platform::errors::Unimplemented("Unimplemented op type %s.",
                                                    op_type));
diff --git a/paddle/fluid/framework/ipu/ipu_utils.cc b/paddle/fluid/framework/ipu/ipu_utils.cc
index 029adb690beffe..cfea2ab89d6124 100644
--- a/paddle/fluid/framework/ipu/ipu_utils.cc
+++ b/paddle/fluid/framework/ipu/ipu_utils.cc
@@ -32,6 +32,8 @@ popart::DataType VarType2PopartType(proto::VarType::Type type) {
       return popart::DataType::INT64;
     case proto::VarType::BOOL:
       return popart::DataType::BOOL;
+    case proto::VarType::FP64:
+      return popart::DataType::DOUBLE;
     case proto::VarType::FP32:
       return popart::DataType::FLOAT;
     case proto::VarType::FP16:
@@ -47,6 +49,40 @@ popart::DataType VarType2PopartType(proto::VarType::Type type) {
         "Unsupported Paddle var type."));
   }
 }
+
+popart::DataType OnnxDtype2PopartType(int type) {
+  auto dtype = static_cast<ONNXDataType>(type);
+  switch (dtype) {
+    case ONNXDataType::BOOL:
+      return popart::DataType::BOOL;
+    case ONNXDataType::INT16:
+      return popart::DataType::INT16;
+    case ONNXDataType::INT32:
+      return popart::DataType::INT32;
+    case ONNXDataType::INT64:
+      return popart::DataType::INT64;
+    case ONNXDataType::FLOAT16:
+      return popart::DataType::FLOAT16;
+    case ONNXDataType::FLOAT:
+      return popart::DataType::FLOAT;
+    case ONNXDataType::DOUBLE:
+      return popart::DataType::DOUBLE;
+    case ONNXDataType::UINT8:
+      return popart::DataType::UINT8;
+    case ONNXDataType::INT8:
+      return popart::DataType::INT8;
+    case ONNXDataType::BFLOAT16:
+      return popart::DataType::BFLOAT16;
+    case ONNXDataType::COMPLEX64:
+      return popart::DataType::COMPLEX64;
+    case ONNXDataType::COMPLEX128:
+      return popart::DataType::COMPLEX128;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported ONNX data type: %d.", dtype));
+  }
+}
+
 // count num should > 0
 bool GetBoolEnv(std::string str) {
   char *str_val = getenv(str.c_str());
diff --git a/paddle/fluid/framework/ipu/ipu_utils.h b/paddle/fluid/framework/ipu/ipu_utils.h
index c8f4785763cacf..f0f5c47c251b52 100644
--- a/paddle/fluid/framework/ipu/ipu_utils.h
+++ b/paddle/fluid/framework/ipu/ipu_utils.h
@@ -26,7 +26,30 @@ namespace paddle {
 namespace framework {
 namespace ipu {
 
+// onnx dtype
+// https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
+enum ONNXDataType : int {
+  UNDEFINED = 0,
+  FLOAT = 1,
+  UINT8 = 2,
+  INT8 = 3,
+  UINT16 = 4,
+  INT16 = 5,
+  INT32 = 6,
+  INT64 = 7,
+  STRING = 8,
+  BOOL = 9,
+  FLOAT16 = 10,
+  DOUBLE = 11,
+  UINT32 = 12,
+  UINT64 = 13,
+  COMPLEX64 = 14,
+  COMPLEX128 = 15,
+  BFLOAT16 = 16
+};
+
 popart::DataType VarType2PopartType(proto::VarType::Type type);
+popart::DataType OnnxDtype2PopartType(int type);
 bool GetBoolEnv(std::string str);
 
 template <typename T>
diff --git a/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.cc b/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.cc
index 1227d177bdb235..c85dbca4ef5ced 100644
--- a/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.cc
+++ b/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.cc
@@ -18,27 +18,6 @@ namespace paddle {
 namespace framework {
 namespace ipu {
 
-// onnx dtype
-// https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
-enum ONNXDataType : int {
-  UNDEFINED = 0,
-  FLOAT = 1,
-  UINT8 = 2,
-  INT8 = 3,
-  UINT16 = 4,
-  INT16 = 5,
-  INT32 = 6,
-  INT64 = 7,
-  STRING = 8,
-  BOOL = 9,
-  FLOAT16 = 10,
-  DOUBLE = 11,
-  UINT32 = 12,
-  UINT64 = 13,
-  COMPLEX64 = 14,
-  COMPLEX128 = 15,
-  BFLOAT16 = 16
-};
 
 // This avoids the static initialisation order fiasco,
 std::unordered_map<std::string, SymbolHandler> &SymbolHandlers() {
diff --git a/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h b/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h
index ce792f1262866b..0c647c577b4a55 100644
--- a/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h
+++ b/paddle/fluid/framework/ipu/popart_canonicalization/canonicalization_utils.h
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ipu/ipu_utils.h"
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc b/paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
index a8e870985e75fb..3e0d216ab53ece 100644
--- a/paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
+++ b/paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
@@ -62,8 +62,53 @@ ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
   return graph->CreateOpNode(op_desc.get());
 }
 
+ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
+  // Op(pow) -> Op(Constant) -> Var(const_out) -> Op(Pow)
+  auto *op = node->Op();
+  // Op(Constant)
+  auto op_const = std::make_unique<framework::OpDesc>();
+  op_const->SetType("Constant");
+  std::string op_const_out = op->Output("Out").front() + ":__0_";
+  auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
+  auto value = std::vector<float>{value_};
+  op_const->SetAttr("value", value);
+  auto dims = std::vector<int64_t>{1};
+  op_const->SetAttr("dims", dims);
+  op_const->SetAttr("dtype", ONNXDataType::FLOAT);
+  std::vector<std::string> outputs_const;
+  outputs_const.push_back(op_const_out);
+  op_const->SetOutput("__outputs__", outputs_const);
+  op_const->Flush();
+  // Var(const_out)
+  auto var_const = std::make_unique<framework::VarDesc>(op_const_out);
+  var_const->SetType(proto::VarType::LOD_TENSOR);
+  var_const->SetDataType(proto::VarType::FP32);
+  auto shape_var_const = std::vector<int64_t>{1};
+  var_const->SetShape(shape_var_const);
+  auto var_node_const = graph->CreateVarNode(var_const.get());
+  auto node_const = graph->CreateOpNode(op_const.get());
+  MoveNodeInputs(node, node_const);
+  ConnectNodes(node_const, var_node_const);
+  // Op(Pow)
+  auto op_pow = std::make_unique<framework::OpDesc>();
+  op_pow->SetType("Pow");
+  std::vector<std::string> inputs;
+  inputs.push_back(op->Input("X").front());
+  inputs.push_back(op_const->Output("__outputs__").front());
+  op_pow->SetInput("__inputs__", inputs);
+  std::vector<std::string> outputs;
+  outputs.push_back(op->Output("Out").front());
+  op_pow->SetOutput("__outputs__", outputs);
+  op_pow->Flush();
+  auto node_pow = graph->CreateOpNode(op_pow.get());
+  ConnectNodes(var_node_const, node_pow);
+  MoveNodeOutputs(node, node_pow);
+  return node_pow;
+}
+
 REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
 REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
+REGISTER_HANDLER(pow, pow_handler);
 
 }  // namespace
 }  // namespace ipu
diff --git a/paddle/fluid/framework/ipu/popart_canonicalization/tensor_ops.cc b/paddle/fluid/framework/ipu/popart_canonicalization/tensor_ops.cc
index 35b2398c50372e..e08608c8775f82 100644
--- a/paddle/fluid/framework/ipu/popart_canonicalization/tensor_ops.cc
+++ b/paddle/fluid/framework/ipu/popart_canonicalization/tensor_ops.cc
@@ -23,18 +23,43 @@ namespace {
 ir::Node *fill_constant_handler(ir::Graph *graph, ir::Node *node) {
   auto *op = node->Op();
   auto op_desc = std::make_unique<framework::OpDesc>();
-  op_desc->SetType("ConstantOfShape");
-
+  op_desc->SetType("Constant");
+  if (!op->Input("ShapeTensor").empty()) {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("op fill_constant with ShapeTensor"));
+  }
   std::vector<std::string> outputs;
   outputs.push_back(op->Output("Out").front());
   op_desc->SetOutput("__outputs__", outputs);
-  auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
-  op_desc->SetAttr("shape", shape);
   auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
   auto dtype = ConvertDataType(dtype_);
   op_desc->SetAttr("dtype", dtype);
-  auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
+  auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
+  op_desc->SetAttr("dims", dims);
+  auto value_ = BOOST_GET_CONST(float, op->GetAttr("value"));
+  size_t size = 1;
+  for (auto &dim : dims) {
+    size *= dim;
+  }
+  Attribute value;
+  switch (dtype_) {
+    case proto::VarType::FP32:
+      value = std::vector<float>(size, value_);
+      break;
+    case proto::VarType::FP64:
+      value = std::vector<double>(size, value_);
+      break;
+    case proto::VarType::INT32:
+      value = std::vector<int>(size, value_);
+      break;
+    case proto::VarType::INT64:
+      value = std::vector<int64_t>(size, value_);
+      break;
+    default:
+      PADDLE_THROW(
+          platform::errors::Unimplemented("fill_constant dtype: %d", dtype_));
+  }
   op_desc->SetAttr("value", value);
   op_desc->Flush();
diff --git a/python/paddle/fluid/tests/unittests/ipu/test_pow_op.py b/python/paddle/fluid/tests/unittests/ipu/test_pow_op.py
new file mode 100644
index 00000000000000..b0085e2c7da827
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/test_pow_op.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import paddle
+import paddle.fluid
+import paddle.fluid.compiler as compiler
+
+paddle.enable_static()
+SEED = 2021
+
+
+@unittest.skipIf(not paddle.is_compiled_with_ipu(),
+                 "core is not compiled with IPU")
+class TestPow(unittest.TestCase):
+    def _test_pow(self, run_ipu=True):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = SEED
+        startup_prog.random_seed = SEED
+        np.random.seed(SEED)
+
+        with paddle.static.program_guard(main_prog, startup_prog):
+            image = paddle.fluid.layers.fill_constant(
+                name="a",
+                shape=[1, 3, 10, 10],
+                dtype='float32',
+                value=3.1415926)
+            out = paddle.pow(image, 2.0)
+
+        if run_ipu:
+            place = paddle.IPUPlace(0)
+        else:
+            place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        exe.run(startup_prog)
+
+        if run_ipu:
+            feed_list = []
+            fetch_list = [out.name]
+            ipu_build_strategy = compiler.get_ipu_build_strategy()
+            ipu_build_strategy.is_training = False
+            program = compiler.IpuCompiler(
+                main_prog, ipu_build_strategy=ipu_build_strategy).compile(
+                    feed_list, fetch_list)
+        else:
+            program = main_prog
+
+        result = exe.run(program, feed={}, fetch_list=[out])
+        return result[0]
+
+    def test_pow(self):
+        ipu_res = self._test_pow(True)
+        cpu_res = self._test_pow(False)
+
+        self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()