@@ -24,7 +24,7 @@ void LowerBlock(torch::jit::Block* b) {
24
24
DropUnusedNodes (b);
25
25
}
26
26
27
- void LowerGraph (std::shared_ptr<torch::jit::Graph>& g, bool disable_cse ) {
27
+ void LowerGraph (std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info ) {
28
28
passes::UnpackHardSwish (g);
29
29
torch::jit::EliminateRedundantGuards (g);
30
30
torch::jit::RemoveListMutation (g);
@@ -42,7 +42,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
42
42
passes::Conv3DToConvolution (g);
43
43
passes::FuseAddMMBranches (g);
44
44
passes::RemoveBNDimCheck (g);
45
- if (!disable_cse) {
45
+ if (!lower_info. disable_cse ) {
46
46
torch::jit::EliminateCommonSubexpression (g);
47
47
}
48
48
// torch::jit::UnrollLoops(g);
@@ -72,25 +72,19 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
72
72
auto g = lowered_mod.get_method (method_name).graph ();
73
73
LOG_GRAPH (*g);
74
74
75
+ LOG_GRAPH (" LibTorch Lowering" );
76
+ auto graph_and_ivalues = torch::jit::LowerGraph (*g, lowered_mod._ivalue ());
77
+
75
78
// Go through TRTorch Lowering to reformat graph to be conversion friendly
76
- // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
79
+ // and also segment for accelerators and executors (TRT-DLA, TRT-GPU , PYT)
77
80
// unfreeze_module is used to not perform constant folding on weights in the network.
78
81
// In quantization aware trained (QAT) models, weights are passed through quantize and
79
82
// dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
80
- if (!lower_info.unfreeze_module ) {
81
- LOG_GRAPH (" TRTorch Graph Lowering" );
82
- lowering::LowerGraph (g, false );
83
- }
83
+ LOG_GRAPH (" TRTorch Graph Lowering" );
84
+ lowering::LowerGraph (graph_and_ivalues.first , lower_info);
84
85
85
- LOG_GRAPH (" LibTorch Lowering" );
86
- auto graph_and_ivalues = torch::jit::LowerGraph (*g, lowered_mod._ivalue ());
87
-
88
- if (lower_info.unfreeze_module ) {
89
- LOG_GRAPH (" TRTorch Graph Lowering" );
90
- lowering::LowerGraph (graph_and_ivalues.first , true );
91
- }
92
86
// Is this necessary?
93
- lowering::LowerBlock (g->block ());
87
+ // lowering::LowerBlock(g->block());
94
88
95
89
return graph_and_ivalues;
96
90
}
0 commit comments