
Commit fad4a10

Committed Apr 7, 2020
feat(//lowering): centralize lowering and try to use PyTorch Conv2DBN folding
before using the converter Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 4b58d3b commit fad4a10
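The headline change is letting PyTorch's own Conv2d + BatchNorm2d folding run over the module before TRTorch conversion. As a rough illustration of what that folding does to the weights (a sketch only, not the actual torch::jit::FoldConvBatchNorm2d implementation; the helper name and tensor arguments are made up here), the BatchNorm statistics get absorbed into the convolution's weight and bias:

#include <ATen/ATen.h>
#include <utility>

// Illustrative Conv2d+BatchNorm2d folding: absorb the BN scale/shift into the
// preceding convolution so the BN node can be dropped from the graph.
std::pair<at::Tensor, at::Tensor> fold_conv_bn(
    const at::Tensor& conv_w,   // [out_c, in_c, kh, kw]
    const at::Tensor& conv_b,   // [out_c]
    const at::Tensor& bn_mean,  // BN running_mean, [out_c]
    const at::Tensor& bn_var,   // BN running_var, [out_c]
    const at::Tensor& gamma,    // BN weight, [out_c]
    const at::Tensor& beta,     // BN bias, [out_c]
    double eps) {
  auto scale = gamma / at::sqrt(bn_var + eps);         // per-channel scale
  auto new_w = conv_w * scale.reshape({-1, 1, 1, 1});  // scale each output filter
  auto new_b = (conv_b - bn_mean) * scale + beta;      // fold mean/shift into the bias
  return {new_w, new_b};
}

Folding at the module level is why the new Lower() entry point below calls LowerModule(mod) before any of the graph-level passes: the converter then only ever sees a plain convolution.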

File tree

3 files changed (+48 / -37 lines)


core/compiler.cpp

+18 / -34 lines
@@ -24,24 +24,24 @@
 namespace trtorch {
 namespace core {

-c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
+c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {

     std::vector<c10::Argument> args;
     for (auto in : g->inputs()) {
         args.push_back(c10::Argument(in->debugName(), in->type()));
     }
-
+
     std::vector<c10::Argument> returns;
     for (auto out : g->outputs()) {
         returns.push_back(c10::Argument(out->debugName(), out->type()));
     }
-
+
     return c10::FunctionSchema(method_name, method_name, args, returns);
 }


 void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
-    execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
+    execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
     auto schema = execution::GetEngineFunctionSchema(uid);
     auto num_io = execution::GetEngineIO(uid);

@@ -53,58 +53,42 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
         in_val->setType(c10::TensorType::get());
         graph_inputs.push_back(in_val);
     }
-
+
     auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second);
     g->block()->appendNode(engine_node);

     for (auto o : engine_node->outputs()) {
         g->registerOutput(o);
     }
-
+
     return;
 }

 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
                                 std::string method_name) {
-    auto g = mod.get_method(method_name).graph();
-    // Go through PyTorch Lowering to simplify graph and extract weight parameters
-    auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
-
-    g = graph_and_parameters.first;
-
-    // Go through TRTorch Lowering to reformat graph to be conversion friendly
-    // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
-    lowering::LowerGraph(g);
-
+    // Go through Lowering to simplify graph and extract weight parameters
+    auto graph_and_parameters = lowering::Lower(mod, method_name);
+
+    auto g = graph_and_parameters.first;
     auto params = graph_and_parameters.second;
     auto named_params = conversion::get_named_params(g->inputs(), params);
     LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
-
-    // Is this necessary?
-    lowering::LowerBlock(g->block());
-
+
     return conversion::VerifyConverterSupportForBlock(g->block());
 }

 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
                                     std::string method_name,
                                     conversion::ExtraInfo cfg) {
-    auto g = mod.get_method(method_name).graph();
-    // Go through PyTorch Lowering to simplify graph and extract weight parameters
-    auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
-
-    g = graph_and_parameters.first;
-
-    // Go through TRTorch Lowering to reformat graph to be conversion friendly
-    // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
-    lowering::LowerGraph(g);
-
+    // Go through Lowering to simplify graph and extract weight parameters
+    auto graph_and_parameters = lowering::Lower(mod, method_name);
+
+    auto g = graph_and_parameters.first;
     auto params = graph_and_parameters.second;
     auto named_params = conversion::get_named_params(g->inputs(), params);
+
     LOG_INFO(*g << "(CompileGraph)\n");
-
-    // Is this necessary?
-    lowering::LowerBlock(g->block());
+
     auto engine = ConvertBlockToEngine(g->block(), cfg, named_params);
     return std::move(engine);
 }
@@ -128,7 +112,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,

     return new_mod;
 }
-
+
 } // namespace core
 } // namespace trtorch

core/lowering/lowering.cpp

+26 / -2 lines
@@ -1,5 +1,7 @@
-#include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/fuse_linear.h"
+#include "torch/csrc/jit/passes/lower_graph.h"
+#include "torch/csrc/jit/passes/quantization.h"

 #include "core/lowering/lowering.h"
 #include "core/lowering/irfusers/irfusers.h"
@@ -22,7 +24,29 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
     //irfusers::UnpackBatchNorm(g);
     //torch::jit::EliminateDeadCode(g);
 }
-
+
+void LowerModule(const torch::jit::script::Module& mod) {
+    torch::jit::FoldConvBatchNorm2d(mod);
+}
+
+std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
+                                                                             std::string method_name) {
+    LowerModule(mod);
+    auto g = mod.get_method(method_name).graph();
+    // Go through PyTorch Lowering to simplify graph and extract weight parameters
+    auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
+
+    g = graph_and_parameters.first;
+
+    // Go through TRTorch Lowering to reformat graph to be conversion friendly
+    // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
+    lowering::LowerGraph(g);
+    // Is this necessary?
+    lowering::LowerBlock(g->block());
+    return graph_and_parameters;
+}
+
+
 } // namespace lowering
 } // namespace core
 } // namespace trtorch

core/lowering/lowering.h

+4 / -1 lines
@@ -5,9 +5,12 @@
 namespace trtorch {
 namespace core {
 namespace lowering {
-
+
 void LowerBlock(torch::jit::Block* b);
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
+void LowerModule(const torch::jit::script::Module& mod);
+std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
+                                                                             std::string method_name);

 } // namespace lowering
 } // namespace core
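With these declarations, both compiler entry points (and any other caller) can run the whole lowering pipeline in one call. A minimal usage sketch, assuming a TorchScript module loaded elsewhere and a method named "forward" (both assumptions, not part of this commit):

#include <iostream>
#include "core/lowering/lowering.h"
#include "torch/script.h"

// Hypothetical caller: lower the "forward" method of an already-loaded module.
void lower_forward(const torch::jit::script::Module& mod) {
  auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");
  auto g = graph_and_parameters.first;        // simplified, conversion-friendly graph
  auto params = graph_and_parameters.second;  // weight tensors extracted by LowerGraph
  std::cout << *g << std::endl;               // inspect the lowered graph
}

This mirrors what CheckMethodOperatorSupport and ConvertGraphToTRTEngine now do in core/compiler.cpp, instead of each repeating the PyTorch and TRTorch lowering passes themselves.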

0 commit comments
