Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Aug 21, 2022
1 parent 01dbf51 commit af52de9
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions cinn/hlir/op/contrib/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ using common::CINNValuePack;

ir::Tensor Scatter(
const ir::Tensor &A, const ir::Tensor &B, const ir::Tensor &C, const Type &axis, const std::string &name) {
std::string extern_fun_name;
if (target.arch == common::Target::Arch::NVGPU) {
extern_fun_name.assign("cinn_host_find_int_from_start");
} else if (target.arch == common::Target::Arch::X86) {
extern_fun_name.assign("cinn_host_find_int_from_start");
} else {
LOG(FATAL) << "ScatterAssign only support X86 and NVGPU ! Please Check.\n";
}

auto pos_axis = axis;
if (pos_axis < 0) pos_axis += input->shape.size();
std::vector<int> new_axes;
Expand Down

0 comments on commit af52de9

Please sign in to comment.