Commit
fix constant op, add pow_op(1-to-N) (PaddlePaddle#46)
gglin001 authored Aug 11, 2021
1 parent 59cf28b commit 42b6581
Showing 9 changed files with 252 additions and 36 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ipu/CMakeLists.txt
@@ -7,9 +7,9 @@ set(POPART_CANONICALIZATION_SRC
"popart_canonicalization/tensor_ops.cc"
"popart_canonicalization/other_ops.cc"
)
cc_library(popart_canonicalization_utils SRCS ${POPART_CANONICALIZATION_SRC} DEPS framework_proto enforce)

cc_library(ipu_device SRCS device.cc DEPS enforce popart)
cc_library(ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart)
cc_library(ipu_build_strategy SRCS ipu_build_strategy.cc DEPS popart graph framework_proto enforce)
cc_library(ipu_backend SRCS ipu_backend.cc DEPS popart graph framework_proto enforce ipu_utils ipu_build_strategy ipu_device graph_helper)
cc_library(popart_canonicalization_utils SRCS ${POPART_CANONICALIZATION_SRC} DEPS framework_proto enforce ipu_utils)
51 changes: 42 additions & 9 deletions paddle/fluid/framework/ipu/ipu_backend.cc
@@ -249,10 +249,12 @@ void IpuBackend::LowerWeights(const ir::Graph* graph) {
}

void IpuBackend::LowerBody(const ir::Graph* graph) {
VLOG(10) << "enter IpuBackend::LowerBody";
auto nodes = TopologySortOperations(*graph);
for (const auto* node : nodes) {
auto* op = node->Op();
auto op_type = op->Type();
VLOG(10) << "Lowering Node: " << node->Name() << " Op: " << op_type;
if (op_type == "RandomUniform") {
auto outputs = op->Output("__outputs__");
auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
@@ -271,16 +273,42 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
popart::TensorId result =
builder_->aiOnnxOpset11().randomnormal(shape, dtype, mean, scale);
tensors_.emplace(outputs[0], result);
} else if (op_type == "ConstantOfShape") {
// TODO(alleng) use RandomUniform for now
} else if (op_type == "Constant") {
auto outputs = op->Output("__outputs__");
auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
auto dtype = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
auto high = value;
auto low = value;
popart::TensorId result =
builder_->aiOnnxOpset11().randomuniform(shape, dtype, high, low);
auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("dims"));
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = OnnxDtype2PopartType(dtype_);
popart::TensorInfo tensor_info{dtype, dims};
auto value_attr = op->GetAttr("value");
auto const_data = std::unique_ptr<popart::ConstVoidData>{};
// Attribute only supports vectors of int, int64_t, float, double, and bool;
// storing any other data type would need a different approach.
switch (dtype) {
case popart::DataType::FLOAT:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<float>, value_attr).data(),
tensor_info));
break;
case popart::DataType::INT32:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<int>, value_attr).data(),
tensor_info));
break;
case popart::DataType::DOUBLE:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<double>, value_attr).data(),
tensor_info));
break;
case popart::DataType::INT64:
const_data.reset(new popart::ConstVoidData(
BOOST_GET_CONST(std::vector<int64_t>, value_attr).data(),
tensor_info));
break;
default:
PADDLE_THROW(
platform::errors::Unimplemented("popart::DataType %d", dtype));
}
popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Add") {
auto inputs = GetOpInputs(op);
@@ -310,6 +338,11 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
popart::TensorId result =
builder_->aiOnnxOpset11().reducemean(inputs, axes, keepdims);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Pow") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
popart::TensorId result = builder_->aiOnnxOpset11().pow(inputs);
tensors_.emplace(outputs[0], result);
} else {
PADDLE_THROW(platform::errors::Unimplemented("Unimplemented op type %s.",
op_type));
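For reference, a minimal standalone sketch of the new Constant lowering path, assuming the PopART Builder API exactly as used in the diff above (shape, dtype, and values are illustrative, and include paths may differ between PopART versions):

#include <vector>

#include <popart/builder.hpp>
#include <popart/tensorinfo.hpp>

int main() {
  auto builder = popart::Builder::create();

  // Host-side data for a 2x2 float constant.
  std::vector<float> values(4, 1.0f);
  popart::TensorInfo info{popart::DataType::FLOAT, {2, 2}};

  // ConstVoidData pairs a raw data pointer with its TensorInfo,
  // just as the switch in LowerBody does per dtype.
  popart::ConstVoidData const_data(values.data(), info);
  popart::TensorId out = builder->aiOnnxOpset11().constant(const_data);
  (void)out;
  return 0;
}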
36 changes: 36 additions & 0 deletions paddle/fluid/framework/ipu/ipu_utils.cc
@@ -32,6 +32,8 @@ popart::DataType VarType2PopartType(proto::VarType::Type type) {
return popart::DataType::INT64;
case proto::VarType::BOOL:
return popart::DataType::BOOL;
case proto::VarType::FP64:
return popart::DataType::DOUBLE;
case proto::VarType::FP32:
return popart::DataType::FLOAT;
case proto::VarType::FP16:
@@ -47,6 +49,40 @@ popart::DataType VarType2PopartType(proto::VarType::Type type) {
"Unsupported Paddle var type."));
}
}

popart::DataType OnnxDtype2PopartType(int type) {
auto dtype = static_cast<ONNXDataType>(type);
switch (dtype) {
case ONNXDataType::BOOL:
return popart::DataType::BOOL;
case ONNXDataType::INT16:
return popart::DataType::INT16;
case ONNXDataType::INT32:
return popart::DataType::INT32;
case ONNXDataType::INT64:
return popart::DataType::INT64;
case ONNXDataType::FLOAT16:
return popart::DataType::FLOAT16;
case ONNXDataType::FLOAT:
return popart::DataType::FLOAT;
case ONNXDataType::DOUBLE:
return popart::DataType::DOUBLE;
case ONNXDataType::UINT8:
return popart::DataType::UINT8;
case ONNXDataType::INT8:
return popart::DataType::INT8;
case ONNXDataType::BFLOAT16:
return popart::DataType::BFLOAT16;
case ONNXDataType::COMPLEX64:
return popart::DataType::COMPLEX64;
case ONNXDataType::COMPLEX128:
return popart::DataType::COMPLEX128;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported ONNX data type: %d.", dtype));
}
}

// count num should > 0
bool GetBoolEnv(std::string str) {
char *str_val = getenv(str.c_str());
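A short usage sketch for the new OnnxDtype2PopartType helper, assuming the header path and namespace shown in this commit:

#include "paddle/fluid/framework/ipu/ipu_utils.h"

int main() {
  using paddle::framework::ipu::OnnxDtype2PopartType;
  // 1 is ONNXDataType::FLOAT per onnx-ml.proto3, mapping to popart::DataType::FLOAT.
  auto dtype = OnnxDtype2PopartType(1);
  (void)dtype;
  return 0;
}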
23 changes: 23 additions & 0 deletions paddle/fluid/framework/ipu/ipu_utils.h
@@ -26,7 +26,30 @@ namespace paddle {
namespace framework {
namespace ipu {

// onnx dtype
// https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
enum ONNXDataType : int {
UNDEFINED = 0,
FLOAT = 1,
UINT8 = 2,
INT8 = 3,
UINT16 = 4,
INT16 = 5,
INT32 = 6,
INT64 = 7,
STRING = 8,
BOOL = 9,
FLOAT16 = 10,
DOUBLE = 11,
UINT32 = 12,
UINT64 = 13,
COMPLEX64 = 14,
COMPLEX128 = 15,
BFLOAT16 = 16
};

popart::DataType VarType2PopartType(proto::VarType::Type type);
popart::DataType OnnxDtype2PopartType(int type);
bool GetBoolEnv(std::string str);

template <typename T>
@@ -18,27 +18,6 @@ namespace paddle {
namespace framework {
namespace ipu {

// onnx dtype
// https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
enum ONNXDataType : int {
UNDEFINED = 0,
FLOAT = 1,
UINT8 = 2,
INT8 = 3,
UINT16 = 4,
INT16 = 5,
INT32 = 6,
INT64 = 7,
STRING = 8,
BOOL = 9,
FLOAT16 = 10,
DOUBLE = 11,
UINT32 = 12,
UINT64 = 13,
COMPLEX64 = 14,
COMPLEX128 = 15,
BFLOAT16 = 16
};

// This avoids the static initialisation order fiasco,
std::unordered_map<std::string, SymbolHandler> &SymbolHandlers() {
@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ipu/ipu_utils.h"

namespace paddle {
namespace framework {
45 changes: 45 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
@@ -62,8 +62,53 @@ ir::Node *reduce_mean_handler(ir::Graph *graph, ir::Node *node) {
return graph->CreateOpNode(op_desc.get());
}

ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
// Op(pow) -> Op(Constant) -> Var(const_out) -> Op(Pow)
auto *op = node->Op();
// Op(Constant)
auto op_const = std::make_unique<framework::OpDesc>();
op_const->SetType("Constant");
std::string op_const_out = op->Output("Out").front() + ":__0_";
auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
auto value = std::vector<float>{value_};
op_const->SetAttr("value", value);
auto dims = std::vector<int64_t>{1};
op_const->SetAttr("dims", dims);
op_const->SetAttr("dtype", ONNXDataType::FLOAT);
std::vector<std::string> outputs_const;
outputs_const.push_back(op_const_out);
op_const->SetOutput("__outputs__", outputs_const);
op_const->Flush();
// Var(const_out)
auto var_const = std::make_unique<framework::VarDesc>(op_const_out);
var_const->SetType(proto::VarType::LOD_TENSOR);
var_const->SetDataType(proto::VarType::FP32);
auto shape_var_const = std::vector<int64_t>{1};
var_const->SetShape(shape_var_const);
auto var_node_const = graph->CreateVarNode(var_const.get());
auto node_const = graph->CreateOpNode(op_const.get());
MoveNodeInputs(node, node_const);
ConnectNodes(node_const, var_node_const);
// Op(Pow)
auto op_pow = std::make_unique<framework::OpDesc>();
op_pow->SetType("Pow");
std::vector<std::string> inputs;
inputs.push_back(op->Input("X").front());
inputs.push_back(op_const->Output("__outputs__").front());
op_pow->SetInput("__inputs__", inputs);
std::vector<std::string> outputs;
outputs.push_back(op->Output("Out").front());
op_pow->SetOutput("__outputs__", outputs);
op_pow->Flush();
auto node_pow = graph->CreateOpNode(op_pow.get());
ConnectNodes(var_node_const, node_pow);
MoveNodeOutputs(node, node_pow);
return node_pow;
}

REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
REGISTER_HANDLER(pow, pow_handler);

} // namespace
} // namespace ipu
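The "1-to-N" in the commit title refers to this pattern: a single Paddle pow op is canonicalized into two ops, a Constant holding the exponent and a Pow consuming it. A hedged PopART-level sketch of the resulting graph, reusing only builder calls that appear elsewhere in this commit (input shape and exponent are illustrative):

#include <vector>

#include <popart/builder.hpp>
#include <popart/tensorinfo.hpp>

int main() {
  auto builder = popart::Builder::create();
  auto opset = builder->aiOnnxOpset11();

  // Graph input, standing in for Paddle's Input("X").
  popart::TensorId x = builder->addInputTensor(
      popart::TensorInfo{popart::DataType::FLOAT, {1, 3, 10, 10}});

  // The exponent as a 1-element constant, as pow_handler emits.
  std::vector<float> factor{2.0f};
  popart::TensorInfo info{popart::DataType::FLOAT, {1}};
  popart::ConstVoidData cv(factor.data(), info);
  popart::TensorId exponent = opset.constant(cv);

  // Pow consumes both tensors, matching the "Pow" branch in LowerBody.
  popart::TensorId y = opset.pow({x, exponent});
  (void)y;
  return 0;
}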
35 changes: 30 additions & 5 deletions paddle/fluid/framework/ipu/popart_canonicalization/tensor_ops.cc
@@ -23,18 +23,43 @@ namespace {
ir::Node *fill_constant_handler(ir::Graph *graph, ir::Node *node) {
auto *op = node->Op();
auto op_desc = std::make_unique<framework::OpDesc>();
op_desc->SetType("ConstantOfShape");

op_desc->SetType("Constant");
if (!op->Input("ShapeTensor").empty()) {
PADDLE_THROW(
platform::errors::Unimplemented("op fill_constant with ShapeTensor"));
}
std::vector<std::string> outputs;
outputs.push_back(op->Output("Out").front());
op_desc->SetOutput("__outputs__", outputs);

auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
op_desc->SetAttr("shape", shape);
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = ConvertDataType(dtype_);
op_desc->SetAttr("dtype", dtype);
auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
op_desc->SetAttr("dims", dims);
auto value_ = BOOST_GET_CONST(float, op->GetAttr("value"));
size_t size = 1;
for (auto &dim : dims) {
size *= dim;
}
Attribute value;
switch (dtype_) {
case proto::VarType::FP32:
value = std::vector<float>(size, value_);
break;
case proto::VarType::FP64:
value = std::vector<double>(size, value_);
break;
case proto::VarType::INT32:
value = std::vector<int>(size, value_);
break;
case proto::VarType::INT64:
value = std::vector<int64_t>(size, value_);
break;
default:
PADDLE_THROW(
platform::errors::Unimplemented("fill_constant dtype: %d", dtype_));
}
op_desc->SetAttr("value", value);

op_desc->Flush();
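The reworked fill_constant handler materializes the scalar value into a dense vector of length prod(dims) before attaching it as the Constant op's value attribute. A minimal sketch of just that step (dims and value are illustrative):

#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> dims{1, 3, 10, 10};
  float value = 3.1415926f;

  size_t size = 1;
  for (auto d : dims) size *= d;  // prod(dims) == 300 here

  // This vector becomes the "value" attribute consumed by the
  // "Constant" branch in IpuBackend::LowerBody.
  std::vector<float> dense(size, value);
  return 0;
}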
74 changes: 74 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_pow_op.py
@@ -0,0 +1,74 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestPow(unittest.TestCase):
def _test_pow(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

with paddle.static.program_guard(main_prog, startup_prog):
image = paddle.fluid.layers.fill_constant(
name="a",
shape=[1, 3, 10, 10],
dtype='float32',
value=3.1415926)
out = paddle.pow(image, 2.0)

if run_ipu:
place = paddle.IPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = []
fetch_list = [out.name]
ipu_build_strategy = compiler.get_ipu_build_strategy()
ipu_build_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_build_strategy=ipu_build_strategy).compile(
feed_list, fetch_list)
else:
program = main_prog

result = exe.run(program, feed={}, fetch_list=[out])
return result[0]

def test_pow(self):
ipu_res = self._test_pow(True)
cpu_res = self._test_pow(False)

self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))


if __name__ == "__main__":
unittest.main()
