 namespace trtorch {
 namespace core {

-c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
+c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {

     std::vector<c10::Argument> args;
     for (auto in : g->inputs()) {
         args.push_back(c10::Argument(in->debugName(), in->type()));
     }
-
+
     std::vector<c10::Argument> returns;
     for (auto out : g->outputs()) {
         returns.push_back(c10::Argument(out->debugName(), out->type()));
     }
-
+
     return c10::FunctionSchema(method_name, method_name, args, returns);
 }


 void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
-    execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
+    execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
     auto schema = execution::GetEngineFunctionSchema(uid);
     auto num_io = execution::GetEngineIO(uid);

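For reference, GenerateGraphSchema simply mirrors the graph's inputs and outputs as schema arguments and returns. A minimal standalone sketch (not part of this commit; the names "x", "out" and the use of <torch/script.h> are illustrative) of the schema it would produce for a one-input, one-output graph:

#include <torch/script.h>
#include <iostream>

int main() {
    // Equivalent of what GenerateGraphSchema builds for a graph with a single
    // Tensor input named "x" and a single Tensor output named "out".
    std::vector<c10::Argument> args{c10::Argument("x", c10::TensorType::get())};
    std::vector<c10::Argument> returns{c10::Argument("out", c10::TensorType::get())};
    c10::FunctionSchema schema("forward", "forward", args, returns);
    std::cout << schema << std::endl;  // prints the textual form, roughly: forward(Tensor x) -> (Tensor out)
    return 0;
}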
@@ -53,58 +53,42 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
         in_val->setType(c10::TensorType::get());
         graph_inputs.push_back(in_val);
     }
-
+
     auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second);
     g->block()->appendNode(engine_node);

     for (auto o : engine_node->outputs()) {
         g->registerOutput(o);
     }
-
+
     return;
 }

 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
                                 std::string method_name) {
-    auto g = mod.get_method(method_name).graph();
-    // Go through PyTorch Lowering to simplify graph and extract weight parameters
-    auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
-
-    g = graph_and_parameters.first;
-
-    // Go through TRTorch Lowering to reformat graph to be conversion friendly
-    // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
-    lowering::LowerGraph(g);
-
+    // Go through Lowering to simplify graph and extract weight parameters
+    auto graph_and_parameters = lowering::Lower(mod, method_name);
+
+    auto g = graph_and_parameters.first;
     auto params = graph_and_parameters.second;
     auto named_params = conversion::get_named_params(g->inputs(), params);
     LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
-
-    // Is this necessary?
-    lowering::LowerBlock(g->block());
-
+
     return conversion::VerifyConverterSupportForBlock(g->block());
 }

 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
                                     std::string method_name,
                                     conversion::ExtraInfo cfg) {
-    auto g = mod.get_method(method_name).graph();
-    // Go through PyTorch Lowering to simplify graph and extract weight parameters
-    auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
-
-    g = graph_and_parameters.first;
-
-    // Go through TRTorch Lowering to reformat graph to be conversion friendly
-    // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
-    lowering::LowerGraph(g);
-
+    // Go through Lowering to simplify graph and extract weight parameters
+    auto graph_and_parameters = lowering::Lower(mod, method_name);
+
+    auto g = graph_and_parameters.first;
     auto params = graph_and_parameters.second;
     auto named_params = conversion::get_named_params(g->inputs(), params);
+
     LOG_INFO(*g << "(CompileGraph)\n");
-
-    // Is this necessary?
-    lowering::LowerBlock(g->block());
+
     auto engine = ConvertBlockToEngine(g->block(), cfg, named_params);
     return std::move(engine);
 }
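The new lowering::Lower(mod, method_name) entry point is not shown in this diff; only its call sites and return value (a pair of lowered graph and extracted parameters) are. A plausible sketch, assembled from the exact calls the removed code made inline; the project header path, the PyTorch pass header path, and whether LowerBlock is still run are assumptions, not confirmed by this commit:

#include <torch/script.h>
#include <torch/csrc/jit/passes/lower_graph.h>  // torch::jit::LowerGraph (path may vary across PyTorch versions)
#include "core/lowering/lowering.h"             // assumed project header declaring LowerGraph/LowerBlock/Lower

#include <memory>
#include <utility>
#include <vector>

namespace trtorch {
namespace core {
namespace lowering {

// Sketch only: combine the PyTorch and TRTorch lowering steps that
// CheckMethodOperatorSupport and ConvertGraphToTRTEngine used to run inline.
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<c10::IValue>>
Lower(const torch::jit::script::Module& mod, std::string method_name) {
    auto g = mod.get_method(method_name).graph();
    // PyTorch lowering: inline the module hierarchy and hoist weights out as graph inputs
    auto lowered = torch::jit::LowerGraph(*g, mod._ivalue());
    g = lowered.first;
    // TRTorch lowering: rewrite the graph into a conversion-friendly form
    LowerGraph(g);
    // The removed code also ran LowerBlock (tagged "Is this necessary?");
    // whether the real Lower keeps this pass is not visible in the diff.
    LowerBlock(g->block());
    return {g, lowered.second};
}

} // namespace lowering
} // namespace core
} // namespace trtorch

Centralizing these steps keeps CheckMethodOperatorSupport and ConvertGraphToTRTEngine in sync: both now consume the same lowered graph and parameter list instead of each re-implementing the sequence.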
@@ -128,7 +112,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,

     return new_mod;
 }
-
+
 } // namespace core
 } // namespace trtorch

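End to end, the entry points touched here can be driven roughly as follows. This is a hypothetical caller, not part of the commit: the header path "core/compiler.h", the model file name, and the method name "forward" are illustrative, and construction of a conversion::ExtraInfo for the conversion step is deliberately left out.

#include <torch/script.h>
#include "core/compiler.h"  // assumed project header declaring the functions in this file

#include <iostream>

int main() {
    // Load a serialized TorchScript module (the file name is illustrative).
    torch::jit::script::Module mod = torch::jit::load("model.ts");

    // Ask the core API whether every operator in "forward" has a converter.
    if (trtorch::core::CheckMethodOperatorSupport(mod, "forward")) {
        std::cout << "forward() is fully convertible\n";
        // A caller would now build a conversion::ExtraInfo (input shapes, precision, ...)
        // and hand it to ConvertGraphToTRTEngine or CompileGraph.
    } else {
        std::cerr << "forward() contains unsupported operators\n";
        return 1;
    }
    return 0;
}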