@@ -240,3 +240,29 @@ TEST(Converters, ATenTransposeConvertsCorrectly) {
240
240
241
241
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
242
242
}
243
+
244
+ TEST (Converters, ATenTransposeNegativeConvertsCorrectly) {
245
+ const auto graph = R"IR(
246
+ graph(%x.1 : Tensor):
247
+ %2 : int = prim::Constant[value=-1]()
248
+ %3 : int = prim::Constant[value=-3]()
249
+ %4 : Tensor = aten::transpose(%x.1, %2, %3)
250
+ return (%4))IR" ;
251
+
252
+ auto g = std::make_shared<torch::jit::Graph>();
253
+ torch::jit::parseIR (graph, &*g);
254
+
255
+ auto in = at::randint (0 , 5 , {2 , 3 , 4 , 5 , 6 }, {at::kCUDA });
256
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
257
+
258
+ std::cout << " Running JIT" << std::endl;
259
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
260
+
261
+ std::cout << " Running TRT" << std::endl;
262
+ in = at::clone (in);
263
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
264
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
265
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
266
+
267
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
268
+ }
0 commit comments