
Commit 6751f6d

zbt78co63oc authored and committed
【complex op No.25】add complex support for cross (PaddlePaddle#63207)
* add complex dtype for cross
* remove temp var when dtype is not complex
1 parent 9437361 commit 6751f6d

6 files changed, +142 −24 lines changed

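For context, a minimal usage sketch of what this commit enables (not part of the diff itself): calling paddle.cross on complex64 tensors in dygraph mode. Shapes and values here are arbitrary.

    import numpy as np
    import paddle

    x = paddle.to_tensor(
        (np.random.random((4, 3)) + 1j * np.random.random((4, 3))).astype('complex64')
    )
    y = paddle.to_tensor(
        (np.random.random((4, 3)) + 1j * np.random.random((4, 3))).astype('complex64')
    )
    out = paddle.cross(x, y, axis=1)  # cross product along the length-3 axis
    print(out.dtype)                  # paddle.complex64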

paddle/phi/kernels/cpu/cross_grad_kernel.cc

+25 −3

@@ -18,6 +18,8 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/for_range.h"

 namespace phi {

@@ -81,9 +83,27 @@ void CrossGradKernel(const Context &dev_ctx,
     slice_size *= static_cast<int>(input_x_dims[i]);
   }

+  int64_t numel = x.numel();
+  DenseTensor x_conj, y_conj;
+  DenseTensorMeta meta_xy(x.dtype(), x.dims());
+  x_conj.set_meta(meta_xy);
+  y_conj.set_meta(meta_xy);
+
+  auto *input_x_conj_data = dev_ctx.template Alloc<T>(&x_conj);
+
+  auto *input_y_conj_data = dev_ctx.template Alloc<T>(&y_conj);
+
+  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+  phi::funcs::ConjFunctor<T> functor_x(
+      input_x.data<T>(), numel, input_x_conj_data);
+  phi::funcs::ConjFunctor<T> functor_y(
+      input_y.data<T>(), numel, input_y_conj_data);
+  for_range(functor_x);
+  for_range(functor_y);
+
   std::vector<T> input_x_vec, input_y_vec, input_dout_vec;
-  phi::TensorToVector(input_x, dev_ctx, &input_x_vec);
-  phi::TensorToVector(input_y, dev_ctx, &input_y_vec);
+  phi::TensorToVector(x_conj, dev_ctx, &input_x_vec);
+  phi::TensorToVector(y_conj, dev_ctx, &input_y_vec);
   phi::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec);
   std::vector<T> out_dx_vec(output_x_grad->numel());
   std::vector<T> out_dy_vec(output_y_grad->numel());
@@ -120,4 +140,6 @@ PD_REGISTER_KERNEL(cross_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
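The new CPU code above conjugates x and y with ConjFunctor and feeds the conjugated buffers into the otherwise unchanged gradient computation. A rough NumPy sketch of that step, assuming the real-dtype kernel effectively computes grad_x = y × dout and grad_y = dout × x (the exact operand order and signs live inside the kernel, not in this hunk):

    import numpy as np

    rng = np.random.default_rng(0)
    x = (rng.random((4, 3)) + 1j * rng.random((4, 3))).astype(np.complex64)
    y = (rng.random((4, 3)) + 1j * rng.random((4, 3))).astype(np.complex64)
    dout = (rng.random((4, 3)) + 1j * rng.random((4, 3))).astype(np.complex64)

    # The step this hunk adds: conjugate both inputs (ConjFunctor + ForRange).
    x_conj, y_conj = np.conj(x), np.conj(y)

    # The rest of the kernel is unchanged: the same real-arithmetic formula is
    # applied, only now to the conjugated operands (conjugate-Wirtinger style).
    grad_x = np.cross(y_conj, dout)
    grad_y = np.cross(dout, x_conj)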

paddle/phi/kernels/cpu/cross_kernel.cc

+10 −2

@@ -105,5 +105,13 @@ void CrossKernel(const Context& dev_ctx,

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    cross, CPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(cross,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::CrossKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

paddle/phi/kernels/gpu/cross_grad_kernel.cu

+47 −14

@@ -18,6 +18,8 @@
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/index_calculator.h"

 namespace phi {
@@ -162,27 +164,56 @@ void CrossGradKernel(const Context& dev_ctx,

   const auto* input_x_data = input_x.data<T>();
   const auto* input_y_data = input_y.data<T>();
+  int64_t numel = x.numel();
   const auto* input_out_grad_data = input_out_grad.data<T>();
   auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
   auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
   auto index_calculator = phi::funcs::IndexCalculator(
       merged_dims.size() - 1, cal_dims, left_strides, full_strides);

-  int64_t numel = x.numel();
   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
-
-  CrossGrad<<<config.block_per_grid,
-              config.thread_per_block,
-              0,
-              dev_ctx.stream()>>>(input_x_data,
-                                  input_y_data,
-                                  input_out_grad_data,
-                                  output_x_grad_data,
-                                  output_y_grad_data,
-                                  full_strides[merge_axis],
-                                  numel / 3,
-                                  index_calculator);
+  if (IsComplexType(x.dtype())) {
+    DenseTensor x_conj, y_conj;
+    DenseTensorMeta meta_xy(x.dtype(), x.dims());
+    x_conj.set_meta(meta_xy);
+    y_conj.set_meta(meta_xy);
+
+    auto* input_x_conj_data = dev_ctx.template Alloc<T>(&x_conj);
+    auto* input_y_conj_data = dev_ctx.template Alloc<T>(&y_conj);
+
+    phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+    phi::funcs::ConjFunctor<T> functor_x(
+        input_x_data, numel, input_x_conj_data);
+    phi::funcs::ConjFunctor<T> functor_y(
+        input_y_data, numel, input_y_conj_data);
+    for_range(functor_x);
+    for_range(functor_y);
+
+    CrossGrad<<<config.block_per_grid,
+                config.thread_per_block,
+                0,
+                dev_ctx.stream()>>>(input_x_conj_data,
+                                    input_y_conj_data,
+                                    input_out_grad_data,
+                                    output_x_grad_data,
+                                    output_y_grad_data,
+                                    full_strides[merge_axis],
+                                    numel / 3,
+                                    index_calculator);
+  } else {
+    CrossGrad<<<config.block_per_grid,
+                config.thread_per_block,
+                0,
+                dev_ctx.stream()>>>(input_x_data,
+                                    input_y_data,
+                                    input_out_grad_data,
+                                    output_x_grad_data,
+                                    output_y_grad_data,
+                                    full_strides[merge_axis],
+                                    numel / 3,
+                                    index_calculator);
+  }
 }
 }  // namespace phi

@@ -195,4 +226,6 @@ PD_REGISTER_KERNEL(cross_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
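Unlike the CPU path, the GPU kernel only builds the conjugate temporaries when the dtype is actually complex (the "remove temp var when dtype is not complex" item in the commit message). A hedged Python sketch of that dispatch; cross_grad_impl is a hypothetical stand-in for the CrossGrad<<<...>>> launch, not a Paddle API:

    import numpy as np

    def cross_grad_impl(x, y, dout):
        # Stand-in for the unchanged kernel; the real operand order/signs live
        # in the CUDA CrossGrad kernel, this is only illustrative.
        return np.cross(y, dout), np.cross(dout, x)

    def cross_grad(x, y, dout):
        if np.iscomplexobj(x):              # mirrors IsComplexType(x.dtype())
            x, y = np.conj(x), np.conj(y)   # conjugate buffers, complex path only
        return cross_grad_impl(x, y, dout)  # identical launch on both paths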

paddle/phi/kernels/gpu/cross_kernel.cu

+3 −1

@@ -172,4 +172,6 @@ PD_REGISTER_KERNEL(cross,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

python/paddle/tensor/linalg.py

+22 −4

@@ -1900,8 +1900,8 @@ def cross(x, y, axis=9, name=None):
     If `axis` is not given, it defaults to the first axis found with the length 3.

     Args:
-        x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64.
-        y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64.
+        x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64, complex64, complex128.
+        y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64, complex64, complex128.
         axis (int, optional): The axis along which to compute the cross product. It defaults to be 9 which indicates using the first axis found with the length 3.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

@@ -1941,13 +1941,31 @@
         check_variable_and_dtype(
             x,
             'x',
-            ['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                "int32",
+                "int64",
+                "complex64",
+                "complex128",
+            ],
             'cross',
         )
         check_variable_and_dtype(
             y,
             'y',
-            ['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                "int32",
+                "int64",
+                "complex64",
+                "complex128",
+            ],
             'cross',
         )
         helper = LayerHelper("cross", **locals())
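The widened check_variable_and_dtype lists matter for the static-graph branch of paddle.cross. A small sketch of what now passes that check, assuming paddle.static.data accepts complex64 here (as it does for other complex-enabled ops); this snippet is illustrative, not taken from the PR:

    import paddle

    paddle.enable_static()
    prog = paddle.static.Program()
    with paddle.static.program_guard(prog):
        x = paddle.static.data(name='x', shape=[-1, 3], dtype='complex64')
        y = paddle.static.data(name='y', shape=[-1, 3], dtype='complex64')
        # Previously rejected by the dtype check above; accepted after this change.
        out = paddle.cross(x, y, axis=1)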

test/legacy_test/test_cross_op.py

+35

@@ -32,6 +32,17 @@ def setUp(self):
             'X': np.random.random(self.shape).astype(self.dtype),
             'Y': np.random.random(self.shape).astype(self.dtype),
         }
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            self.inputs = {
+                'X': (
+                    np.random.random(self.shape)
+                    + 1j * np.random.random(self.shape)
+                ).astype(self.dtype),
+                'Y': (
+                    np.random.random(self.shape)
+                    + 1j * np.random.random(self.shape)
+                ).astype(self.dtype),
+            }
         self.init_output()

     def initTestCase(self):
@@ -81,6 +92,30 @@ def init_output(self):
         self.outputs = {'Out': np.array(z_list).reshape(self.shape)}


+class TestCrossComplex64Op(TestCrossOp):
+    def initTestCase(self):
+        self.shape = (2048, 3)
+        self.dtype = np.complex64
+
+    def init_output(self):
+        z_list = []
+        for i in range(2048):
+            z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
+        self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
+
+
+class TestCrossComplex128Op(TestCrossOp):
+    def initTestCase(self):
+        self.shape = (2048, 3)
+        self.dtype = np.complex128
+
+    def init_output(self):
+        z_list = []
+        for i in range(2048):
+            z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
+        self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
