Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Aug 21, 2022
1 parent b30a7da commit 7fd55f7
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 7 deletions.
8 changes: 4 additions & 4 deletions cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,21 @@ class NetBuilder : public BaseBuilder {

Variable GatherNd(const Variable& x, const Variable& index, const std::vector<int>& axes = {});

Variable Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis = 0);
Variable Scatter(const Variable& src,
const Variable& index,
const std::vector<int>& shape,
const float& default_value = 0,
const int& axis = 0);
Variable Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis = 0);

Variable ScatterNd(const Variable& src,
const Variable& index,
const std::vector<int>& shape,
const float& default_value = 0,
const Variable& out,
const std::vector<int>& axes = {});
Variable ScatterNd(const Variable& src,
const Variable& index,
const Variable& out,
const std::vector<int>& shape,
const float& default_value = 0,
const std::vector<int>& axes = {});

/**
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ core_gather_headers()

gather_srcs(cinnapi_src SRCS
gather.cc
# scatter.cc
scatter.cc
)

cc_test(test_gather SRCS gather_test.cc DEPS cinncore)
Expand Down
52 changes: 50 additions & 2 deletions cinn/hlir/op/contrib/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ ir::Tensor Scatter(
C->shape,
[=](const std::vector<Expr> &indices) {
auto start = Expr(0);
for (auto s : new_B->shape) {
start = common::AutoSimplify(start * s);
for (int i = 0; i < new_B->shape.size() - 1; ++i) {
start = common::AutoSimplify(start * new_B->shape[i]);
}
auto id = lang::CallExtern(extern_fun_name, {B, B->shape[-1], indices[pos_axis], start});

Expand All @@ -85,6 +85,54 @@ ir::Tensor Scatter(
return res;
}

ir::Tensor ScatterNd(
const ir::Tensor &A, const ir::Tensor &B, const ir::Tensor &C, const Type &axes, const std::string &name) {
std::string extern_fun_name;
if (target.arch == common::Target::Arch::NVGPU) {
extern_fun_name.assign("cinn_host_find_int_from_start");
} else if (target.arch == common::Target::Arch::X86) {
extern_fun_name.assign("cinn_host_find_int_from_start");
} else {
LOG(FATAL) << "ScatterAssign only support X86 and NVGPU ! Please Check.\n";
}

std::vector<int> pos_axes;
for (axis : axes) {
if (axis < 0) {
pos_axes.push_back(axis + C->shape.size());
} else {
pos_axes.push_back(axis);
}
}
std::vector<int> new_axes;
for (int i = 0; i < C->shape.size(); ++i) {
if (i != axis) {
new_axes.push_back(i);
}
}
new_axes.insert(new_axes.end(), axes.begin(), axes.end()) auto new_B =
pe::Transpose(B, new_axes, name + "_index_transpose");

auto res = Compute(
C->shape,
[=](const std::vector<Expr> &indices) {
auto start = Expr(0);
for (int i = 0; i < new_B->shape.size() - axes.size(); ++i) {
start = common::AutoSimplify(start * new_B->shape[i]);
}

auto id = lang::CallExtern(extern_fun_name, {B, B->shape[-1], indices[pos_axis], start});

std::vector<Expr> src_indices(indices);
src_indices[pos_axis] = id;

auto update = ir::EQ::Make(id, Expr(-1));
return ir::Select::Make(update, C(indice), A(src_indices));
},
name);
return res;
}

std::shared_ptr<framework::OpStrategy> StrategyForScatter(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
Expand Down

0 comments on commit 7fd55f7

Please sign in to comment.