Create CinnCompiler class for compiling subgraphs found by build_cinn_pass. #36562

Merged
merged 16 commits into PaddlePaddle:develop from the get_graph branch on Oct 25, 2021

Conversation

@wzzju (Contributor) commented Oct 19, 2021

PR types

New features

PR changes

Others

Describe

This PR creates the CinnCompiler class for compiling the subgraphs found by build_cinn_pass. The process is as follows:

  1. build_cinn_pass finds all subgraphs that can be compiled with CINN and calls the AddGraph API of CinnCompiler to register them.

  2. When CinnLaunchOp is executed, it will find and compile its corresponding subgraph through the Compile API of CinnCompiler. After that, the CinnLaunchOp will get a CinnCompiledObject, which is used to perform the actual computation.

For more details, refer to the following test code:

auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key);
// viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));

// Looking up a nonexistent key must raise an exception.
EXPECT_THROW(cinn_compiler->FindGraph("no_existed"),
             paddle::platform::EnforceNotMet);

// Prepare the input tensors fed to compilation.
LoDTensor tensor1, tensor2, tensor3;
tensor1.Resize({1000, 784});
tensor2.Resize({784, 100});
tensor3.Resize({100});
tensor1.mutable_data<float>(platform::CPUPlace());
tensor2.mutable_data<float>(platform::CPUPlace());
tensor3.mutable_data<float>(platform::CPUPlace());
std::map<std::string, const LoDTensor*> input_tensors = {
    {"X", &tensor1}, {"Y", &tensor2}, {"Z", &tensor3}};

auto compile_fn = [&](const Target& target) {
  const auto& compiled_obj =
      cinn_compiler->Compile(compiling_graph, input_tensors, target);
  ASSERT_NE(compiled_obj.runtime_program, nullptr);
  ASSERT_NE(compiled_obj.scope, nullptr);
  ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty());
  // A second compilation of the same key must hit the cache and return
  // the identical object.
  const auto& cached_obj =
      cinn_compiler->Compile(compilation_key, input_tensors, target);
  ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
            reinterpret_cast<std::uint64_t>(&cached_obj));
};
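
Step 1 (registration by build_cinn_pass) is not exercised in the snippet above. A minimal sketch of that side might look like the following; the graph construction is a hypothetical placeholder, and the exact AddGraph signature is inferred from the FindGraph/Compile usage above:

auto* cinn_compiler = CinnCompiler::GetInstance();
// build_cinn_pass hands each CINN-compilable subgraph to CinnCompiler;
// AddGraph stores the graph and returns the key that is later passed to
// FindGraph/Compile as compilation_key.
std::unique_ptr<Graph> subgraph = ExtractSubGraph();  // hypothetical helper
std::string compilation_key = cinn_compiler->AddGraph(std::move(subgraph));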

The execution result of cinn_compiler_test is as follows:

test 149
    Start 149: cinn_compiler_test

149: Test command: /work/Paddle/build/paddle/fluid/framework/paddle2cinn/cinn_compiler_test
149: [==========] Running 1 test from 1 test case.
149: [----------] Global test environment set-up.
149: [----------] 1 test from CinnCompilerTest
149: [ RUN      ] CinnCompilerTest.Compile
149: WARNING: Logging before InitGoogleLogging() is written to STDERR
149: I1024 05:57:31.278139 24525 compiler.cc:73] [CUDA] host module:
149: Module module_host {
149: 
149: function fn_mul_0 (args__ptr, num_args)
149: {
149:   fn_mul_0_kernel(args__ptr, num_args)
149: }
149: function fn_const_scalar_2_broadcast_to_3_elementwise_add_1_max_4_fused (args__ptr, num_args)
149: {
149:   fn_const_scalar_2_broadcast_to_3_elementwise_add_1_max_4_fused_kernel(args__ptr, num_args)
149: }
149: 
149: 
149: }
149: I1024 05:57:31.278211 24525 compiler.cc:76] [CUDA] device module:
149: Module module_gpu_device {
149: 
149: function fn_mul_0_kernel (_X, _Y, _Mul_output)
149: {
149:   if ((blockIdx.x < 1000)) {
149:     if ((threadIdx.x < 50)) {
149:       for (j_inner, 0, 2)
149:       {
149:         Mul_output__reduce_init[blockIdx.x, ((2 * threadIdx.x) + j_inner)] = 0
149:         for (reduce_k, 0, 784)
149:         {
149:           Mul_output[blockIdx.x, ((2 * threadIdx.x) + j_inner)] = (Mul_output[blockIdx.x, ((2 * threadIdx.x) + j_inner)] + (X_reshape[blockIdx.x, reduce_k] * Y_reshape[((2 * threadIdx.x) + j_inner), reduce_k]))
149:         }
149:       }
149:     }
149:   }
149: }
149: function fn_const_scalar_2_broadcast_to_3_elementwise_add_1_max_4_fused_kernel (_var_3, _Z, _max_Out)
149: {
149:   if ((blockIdx.x < 98)) {
149:     if ((threadIdx.x < cinn_min(1024, (100000 + (-1024 * blockIdx.x))))) {
149:       max_Out[((((24 * blockIdx.x) + threadIdx.x) / 100) + (10 * blockIdx.x)), (((24 * blockIdx.x) + threadIdx.x) % 100)] = cinn_max((var_3[((((24 * blockIdx.x) + threadIdx.x) / 100) + (10 * blockIdx.x)), (((24 * blockIdx.x) + threadIdx.x) % 100)] + Z[(((24 * blockIdx.x) + threadIdx.x) % 100)]), 0)
149:     }
149:   }
149: }
149: 
149: 
149: }
149: I1024 05:57:31.309249 24525 compiler.cc:80] [CUDA] source code:
149: extern "C" {
149: 
149: #include "cinn_cuda_runtime_source.cuh"
149: 
149: #ifdef __CUDACC_RTC__
149: typedef int int32_t;
149: typedef char int8_t;
149: #endif
149: 
149: 
149: 
149: __global__
149: void fn_mul_0_kernel(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ Mul_output)
149: {
149:   float* Mul_output__reduce_init = Mul_output;
149:   const float* X_reshape = X;
149:   const float* Y_reshape = Y;
149:   if ((blockIdx.x < 1000)) {
149:     if ((threadIdx.x < 50)) {
149:       for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
149:         Mul_output__reduce_init[((100 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] = 0;
149:         for (int32_t reduce_k = 0; reduce_k < 784; reduce_k += 1) {
149:           Mul_output[((100 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] = (Mul_output[((100 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] + (X_reshape[((784 * blockIdx.x) + reduce_k)] * Y_reshape[((784 * j_inner) + ((1568 * threadIdx.x) + reduce_k))]));
149:         };
149:       };
149:     };
149:   };
149: }__global__
149: void fn_const_scalar_2_broadcast_to_3_elementwise_add_1_max_4_fused_kernel(const float* __restrict__ var_3, const float* __restrict__ Z, float* __restrict__ max_Out)
149: {
149:   if ((blockIdx.x < 98)) {
149:     if ((threadIdx.x < cinn_nvgpu_min_fp32(1024, (100000 + (-1024 * blockIdx.x))))) {
149:       max_Out[((100 * (((24 * blockIdx.x) + threadIdx.x) / 100)) + ((((24 * blockIdx.x) + threadIdx.x) % 100) + (1000 * blockIdx.x)))] = cinn_nvgpu_max_fp32((var_3[((100 * (((24 * blockIdx.x) + threadIdx.x) / 100)) + ((((24 * blockIdx.x) + threadIdx.x) % 100) + (1000 * blockIdx.x)))] + Z[(((24 * blockIdx.x) + threadIdx.x) % 100)]), 0);
149:     };
149:   };
149: }
149: 
149: }
149: I1024 05:57:31.311307 24525 nvrtc_util.cc:94] compile options: -arch=compute_70 --include-path=/usr/local/cuda/include --include-path=/work/Develop/sync_work/Paddle/build_venv/third_party/CINN/src/external_cinn/cinn/runtime/cuda
149: E1024 05:57:32.077482 24525 lower_impl.cc:358] tensor [mul_mkl_out] buffer is null
149: [       OK ] CinnCompilerTest.Compile (1538 ms)
149: [----------] 1 test from CinnCompilerTest (1538 ms total)
149: 
149: [----------] Global test environment tear-down
149: [==========] 1 test from 1 test case ran. (1538 ms total)
149: [  PASSED  ] 1 test.
1/1 Test #149: cinn_compiler_test ...............   Passed    1.85 sec

The following tests passed:
	cinn_compiler_test

100% tests passed, 0 tests failed out of 1

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@wzzju force-pushed the get_graph branch 2 times, most recently from 17fa20b to 08cb6f3, on October 19, 2021 at 14:11
@wzzju marked this pull request as draft on October 19, 2021 at 14:30
@wzzju marked this pull request as ready for review on October 21, 2021 at 13:16
@wzzju changed the title from "Add the functions of CinnCompiler." to "Create CinnCompiler class for compiling subgraphs found by build_cinn_pass." on Oct 22, 2021
// from CinnCacheKey to CinnCompiledObject. If the cache hits, we will reuse
// the cached CinnCompiledObject; otherwise, we will compile again and put the
// result into the cache.
class CinnCompiler {
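
As a self-contained sketch of the caching behavior this comment describes (stand-in types and a std::string key instead of CinnCacheKey; not the actual Paddle implementation):

#include <map>
#include <memory>
#include <string>

struct CompiledObjectStub {};  // stand-in for CinnCompiledObject

class CompilerCacheSketch {
 public:
  const CompiledObjectStub& Compile(const std::string& cache_key) {
    auto it = cache_.find(cache_key);
    if (it == cache_.end()) {
      // Cache miss: compile again and put the result into the cache.
      it = cache_.emplace(cache_key,
                          std::make_unique<CompiledObjectStub>()).first;
    }
    // Cache hit (or a fresh entry): reuse the stored object.
    return *it->second;
  }

 private:
  // The cache owns its entries exclusively via unique_ptr.
  std::map<std::string, std::unique_ptr<CompiledObjectStub>> cache_;
};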
Member commented:

I personally don't think it is a good name.

  1. The class is an interface/runner/controller that invokes CINN, so it is not the actual CINN compiler.
  2. CINN stands for "Compiler Infrastructure for Neural Networks", which already carries the meaning of "compiler", so I think we should use another term in place of "Compiler".

@wzzju (Contributor, Author) commented Oct 24, 2021:

Thanks for the comment. It was previously named CinnRunner because at the time we hoped execution would also live in this class. Now the class only carries the compilation functionality, so it was renamed CinnCompiler, which also follows the practice of other comparable projects. Personally, I think that although CINN carries the meaning of "compiler", here we can treat it as a proper noun; of course, that is just my own view. We can discuss the details offline. Thanks again.

Member commented:

Personally, I still lean toward point 1 that I raised: this class is a class that invokes CINN, not the actual CINN compiler, so I feel a name like some kind of controller would be better. That said, I'm not completely against the name CinnCompiler; if we go with that name, I'm fine with it too, haha.

return graph_key;
}

Graph* CinnCompiler::FindGraph(const std::string& graph_key) const {
Member commented:

Can we use shared_ptr<Graph> in graphs_ and as the function return type? If our caller gets a raw pointer and we update graphs_ in the future, I'm afraid the unique_ptr may be destroyed.

Personally, I prefer not to mix raw pointers and smart pointers in the same unit.

@wzzju (Contributor, Author) commented Oct 24, 2021:

Thanks for the comment. In my opinion, choosing between shared_ptr and unique_ptr depends on whether ownership can be clearly assigned, i.e. whether ownership is shared or clearly belongs to one specific class. Here, the found subgraphs to be compiled clearly belong to CinnCompiler. In addition, since build_cinn_pass runs only once, graphs_ will not be updated here; and even if graphs_ really were updated, shared_ptr would not completely solve the problem either, and in a multi-threaded environment a lock would probably be needed to make it safe.

To prevent external users from misusing the pointer (e.g. calling delete on it), I finally chose, following your suggestion, to return a const reference.
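
For reference, a sketch of the shape FindGraph takes after this change; the missing-key check (which the test above expects to raise EnforceNotMet) is elided:

// graphs_ keeps exclusive ownership via unique_ptr; callers only borrow
// through a const reference, so they can neither delete nor reseat the graph.
const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
  return *graphs_.at(graph_key);
}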

Member commented:

Sounds good to me!

A const reference works for me. My small concern was that the usage felt like sharing, which is why I suggested shared_ptr, but a const reference is already fine.

return graphs_.at(graph_key).get();
}

CinnCompiledObject* CinnCompiler::Compile(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above: should we use shared_ptr in cache_?

@wzzju (Contributor, Author) commented:

Thanks for the comment. Same reply as above: since the CinnCompiledObject results in cache_ clearly belong to CinnCompiler, unique_ptr is more reasonable, which also follows the practice of other comparable projects. To prevent external users from misusing the pointer (e.g. calling delete on it), I finally chose, following your suggestion, to return a const reference.

Member commented:

Sounds good to me

@wzzju merged commit 4c46037 into PaddlePaddle:develop on Oct 25, 2021