Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Sep 10, 2022
1 parent e714d03 commit 6d05ac3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
4 changes: 2 additions & 2 deletions cinn/hlir/op/contrib/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(const framework::NodeAttr
auto stages = CreateStages({tensor_A});
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
auto tensor_name = UniqName("Sort_out");
auto tensor_name = UniqName("Sort_out");
if (FLAGS_cinn_ir_schedule) {
CHECK_EQ(pack_args.size(), 2U);
CHECK(pack_args[1].is_string());
Expand Down Expand Up @@ -224,7 +224,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(const framework::NodeA
CHECK(pack_args[1].is_string());
tensor_name = pack_args[1].operator std::string();
}
ir::Tensor out = ArgSort(tensor_A, target, axis, is_ascend, tensor_name);
ir::Tensor out = ArgSort(tensor_A, target, axis, is_ascend, tensor_name);
std::vector<CINNValue> res;
stages->InsertLazily(out);
res.push_back(CINNValue(out));
Expand Down
23 changes: 18 additions & 5 deletions cinn/runtime/cpu/host_intrinsics_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ TEST(tanh, basic) {
TEST(find_value_nd, basic) {
Expr M(10), N(20);
Placeholder<float> x("x", {M, N});
auto y = Compute({N}, [&](Expr i) {
return CallExtern("cinn_host_find_float_nd", {x, M, x({Expr(5), Expr(3)}), i, N});
});
auto y = Compute(
{N},
[&](Expr i) {
return CallExtern("cinn_host_find_float_nd", {x, M, x({Expr(5), Expr(3)}), i, N});
},
"y");

auto stages = CreateStages({y});

Expand Down Expand Up @@ -109,7 +112,12 @@ TEST(find_value_nd, basic) {
TEST(cinn_host_lt_num_float, basic) {
Expr M(10), N(20);
Placeholder<float> x("x", {M, N});
auto y = Compute({N}, [&](Expr j) { return CallExtern("cinn_host_lt_num_float", {x, M, x({Expr(0), j}), j, N}); });
auto y = Compute(
{N},
[&](Expr j) {
return CallExtern("cinn_host_lt_num_float", {x, M, x({Expr(0), j}), j, N});
},
"y");

auto stages = CreateStages({y});

Expand Down Expand Up @@ -151,7 +159,12 @@ TEST(cinn_host_lt_num_float, basic) {
TEST(cinn_host_gt_num_float, basic) {
Expr M(10), N(20);
Placeholder<float> x("x", {M, N});
auto y = Compute({N}, [&](Expr j) { return CallExtern("cinn_host_gt_num_float", {x, M, x({Expr(0), j}), j, N}); });
auto y = Compute(
{N},
[&](Expr j) {
return CallExtern("cinn_host_gt_num_float", {x, M, x({Expr(0), j}), j, N});
},
"y");

auto stages = CreateStages({y});

Expand Down

0 comments on commit 6d05ac3

Please sign in to comment.