Skip to content

Commit 8dc1a06

Browse files
authored
fix: Repair Citrinet-1024 compilation issues [Duplicate of PR #1488 for Release 1.3] (#1489)
1 parent 8d7cd50 commit 8dc1a06

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

core/conversion/converters/impl/element_wise.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ auto element_wise_registrations TORCHTRT_UNUSED =
325325
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
326326
} else if (rounding_mode == "trunc") {
327327
// trunc = floor(abs(div)) * sign(div)
328-
auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
328+
auto tmp_div = add_elementwise(
329+
ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_tmp_div");
329330
auto abs = add_abs(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
330331

331332
// In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this

core/conversion/converters/impl/reduce.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ auto reduce_registrations TORCHTRT_UNUSED =
113113
LOG_DEBUG("Keep dims: " << keepdim);
114114

115115
LOG_WARNING("Sum converter disregards dtype");
116+
117+
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
118+
LOG_DEBUG(
119+
"Found type " << in_tensor->getType() << " in aten::sum, casting to "
120+
<< nvinfer1::DataType::kINT32 << " for compatibility.");
121+
in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32);
122+
}
123+
116124
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
117125

118126
TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);

tests/core/conversion/converters/test_reduce.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "tests/util/util.h"
66
#include "torch/csrc/jit/ir/irparser.h"
77
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
8+
#include "torch/torch.h"
89

910
namespace {
1011
std::string gen_basic_graph(const std::string& op) {
@@ -162,6 +163,19 @@ TEST(Converters, ATenSumDimNegOneIndexKeepDimsConvertsCorrectly) {
162163
test_body(graph, in);
163164
}
164165

166+
TEST(Converters, ATenSumDimNegOneIndexKeepDimsBoolTensorConvertsCorrectly) {
167+
const auto graph = R"IR(
168+
graph(%0 : Tensor):
169+
%1 : int = prim::Constant[value=-1]()
170+
%2 : int[] = prim::ListConstruct(%1)
171+
%3 : bool = prim::Constant[value=1]()
172+
%4 : None = prim::Constant()
173+
%5 : Tensor = aten::sum(%0, %2, %3, %4)
174+
return (%5))IR";
175+
auto in = at::randint(0, 2, {4, 4, 4}, at::kCUDA).to(torch::kBool);
176+
test_body(graph, in);
177+
}
178+
165179
TEST(Converters, ATenSumDimNegIndexConvertsCorrectly) {
166180
const auto graph = R"IR(
167181
graph(%0 : Tensor):

0 commit comments

Comments
 (0)