Skip to content

Commit 4a1d2f3

Browse files
committedMar 1, 2021
feat: support aten::transpose with negative dim
Signed-off-by: inocsin <vcheungyi@163.com>
1 parent b1b5f19 commit 4a1d2f3

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed
 

‎core/conversion/converters/impl/shuffle.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ static auto shuffle_registrations TRTORCH_UNUSED =
106106
for (size_t i = 0; i < ndims; i++) {
107107
new_order.push_back(i);
108108
}
109+
dim0 = dim0 < 0 ? (dim0 + ndims) : dim0;
110+
dim1 = dim1 < 0 ? (dim1 + ndims) : dim1;
109111
auto tmp = dim0;
110112
new_order[dim0] = new_order[dim1];
111113
new_order[dim1] = tmp;

‎tests/core/conversion/converters/test_shuffle.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,29 @@ TEST(Converters, ATenTransposeConvertsCorrectly) {
240240

241241
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
242242
}
243+
244+
TEST(Converters, ATenTransposeNegativeConvertsCorrectly) {
245+
const auto graph = R"IR(
246+
graph(%x.1 : Tensor):
247+
%2 : int = prim::Constant[value=-1]()
248+
%3 : int = prim::Constant[value=-3]()
249+
%4 : Tensor = aten::transpose(%x.1, %2, %3)
250+
return (%4))IR";
251+
252+
auto g = std::make_shared<torch::jit::Graph>();
253+
torch::jit::parseIR(graph, &*g);
254+
255+
auto in = at::randint(0, 5, {2, 3, 4, 5, 6}, {at::kCUDA});
256+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
257+
258+
std::cout << "Running JIT" << std::endl;
259+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
260+
261+
std::cout << "Running TRT" << std::endl;
262+
in = at::clone(in);
263+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
264+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
265+
auto trt = trt_results[0].reshape_as(jit_results[0]);
266+
267+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
268+
}

0 commit comments

Comments
 (0)
Please sign in to comment.