
Commit 13eef91

Committed Aug 10, 2021
fix: Address review comments, fix failing tests due to bool mishandling
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 6844a7f · commit 13eef91

File tree: 5 files changed, +21 -19 lines changed


core/lowering/lowering.cpp

+9 -15

@@ -24,7 +24,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }

-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::UnpackHardSwish(g);
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
@@ -42,7 +42,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
   passes::RemoveBNDimCheck(g);
-  if (!disable_cse) {
+  if (!lower_info.disable_cse) {
     torch::jit::EliminateCommonSubexpression(g);
   }
   // torch::jit::UnrollLoops(g);
@@ -72,25 +72,19 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
   auto g = lowered_mod.get_method(method_name).graph();
   LOG_GRAPH(*g);

+  LOG_GRAPH("LibTorch Lowering");
+  auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
+
   // Go through TRTorch Lowering to reformat graph to be conversion friendly
-  // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
+  // and also segment for accelerators and executors (TRT-DLA, TRT-GPU , PYT)
   // unfreeze_module is used to not perform constant folding on weights in the network.
   // In quantization aware trained (QAT) models, weights are passed through quantize and
   // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
-  if (!lower_info.unfreeze_module) {
-    LOG_GRAPH("TRTorch Graph Lowering");
-    lowering::LowerGraph(g, false);
-  }
+  LOG_GRAPH("TRTorch Graph Lowering");
+  lowering::LowerGraph(graph_and_ivalues.first, lower_info);

-  LOG_GRAPH("LibTorch Lowering");
-  auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
-
-  if (lower_info.unfreeze_module) {
-    LOG_GRAPH("TRTorch Graph Lowering");
-    lowering::LowerGraph(graph_and_ivalues.first, true);
-  }
   // Is this necessary?
-  lowering::LowerBlock(g->block());
+  // lowering::LowerBlock(g->block());

   return graph_and_ivalues;
 }
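Read end to end, the hunk above reorders Lower(): LibTorch graph lowering now runs first, and the TRTorch passes are applied once, unconditionally, to its output, with the flags in LowerInfo deciding inside LowerGraph which individual passes are skipped. A condensed sketch of the resulting tail of Lower(), paraphrased from the diff rather than copied verbatim:

  // Condensed sketch of the reworked tail of Lower() (paraphrase of the hunk above, not an exact excerpt).
  LOG_GRAPH("LibTorch Lowering");
  auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

  // TRTorch lowering now always runs on the LibTorch-lowered graph; fields of
  // lower_info (e.g. disable_cse for QAT) control which passes execute inside.
  LOG_GRAPH("TRTorch Graph Lowering");
  lowering::LowerGraph(graph_and_ivalues.first, lower_info);

  return graph_and_ivalues;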

core/lowering/lowering.h

+8 -2

@@ -8,11 +8,17 @@ namespace lowering {

 struct LowerInfo {
   // Internal flag to ensure torch.jit.Module does not get freezed in lowering.cpp. This is required for QAT models.
-  bool unfreeze_module;
+  bool unfreeze_module = false;
+  // CommonSubexpressionElimination removes duplicate expressions which are used frequently in the graph.
+  // for eg: CSE replaces similar value-d stride nodes of multiple conv layers in a network with a single stride node.
+  // In QAT models, if two conv layers are consuming same input, there is a QDQ node for each input of the conv.
+  // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
+  // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
+  bool disable_cse = false;
 };

 void LowerBlock(torch::jit::Block* b);
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse = false);
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
 torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
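For callers, the practical effect of the new struct is sketched below. This is illustrative only: it assumes the trtorch::core::lowering namespace nesting and the core/lowering/lowering.h include path used elsewhere in this tree, and both helper functions are hypothetical, not part of the commit. A default-constructed LowerInfo keeps the previous behaviour (CSE enabled, module frozen), while the QAT/INT8 path flips both flags, mirroring the compile-spec changes further down.

#include "core/lowering/lowering.h"  // assumed include path, matching this file's location

// Hypothetical helper: non-QAT lowering keeps the defaults, so CSE still runs.
void lower_default(std::shared_ptr<torch::jit::Graph>& graph) {
  trtorch::core::lowering::LowerInfo lower_info;  // unfreeze_module = false, disable_cse = false
  trtorch::core::lowering::LowerGraph(graph, lower_info);
}

// Hypothetical helper: QAT/INT8 lowering disables CSE so the duplicate QDQ nodes
// described above survive; per the header comment, unfreeze_module governs module
// freezing during lowering.
void lower_for_qat(std::shared_ptr<torch::jit::Graph>& graph) {
  trtorch::core::lowering::LowerInfo lower_info;
  lower_info.unfreeze_module = true;
  lower_info.disable_cse = true;
  trtorch::core::lowering::LowerGraph(graph, lower_info);
}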

cpp/api/src/compile_spec.cpp

+1 -0

@@ -409,6 +409,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
     internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
   } else {
     internal.lower_info.unfreeze_module = true;
+    internal.lower_info.disable_cse = true;
     internal.convert_info.engine_settings.calibrator = nullptr;
   }
 } else {

py/trtorch/csrc/tensorrt_backend.cpp

+2 -2

@@ -27,12 +27,12 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
   mod = core::lowering::LowerModule(mod);

   auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
-  core::CompileSpec cfg({});
+  lowering::LowerInfo lower_info;
   for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
     const auto& method_name = it->key();
     auto method = mod.get_method(method_name);
     auto graph = method.graph();
-    core::lowering::LowerGraph(graph, cfg.lower_info);
+    core::lowering::LowerGraph(graph, lower_info);
   }

   auto handles = c10::impl::GenericDict(

py/trtorch/csrc/tensorrt_classes.cpp

+1 -0

@@ -188,6 +188,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
     if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
         info.convert_info.engine_settings.enabled_precisions.end()) {
       info.lower_info.unfreeze_module = true;
+      info.lower_info.disable_cse = true;
     }
   }
   info.convert_info.engine_settings.sparse_weights = sparse_weights;
