Skip to content

Commit

Permalink
fix MKL-based FFT implementation (PaddlePaddle#44)
Browse files Browse the repository at this point in the history
* Fix the MKL-based FFT implementation: in MKL DFTI, the forward domain (DFTI_FORWARD_DOMAIN) is always REAL for both R2C and C2R transforms
  • Loading branch information
Feiyu Chan authored Sep 15, 2021
1 parent c0289d1 commit b3d5f13
Showing 1 changed file with 62 additions and 18 deletions.
80 changes: 62 additions & 18 deletions paddle/fluid/operators/spectral_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,20 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
}
}();

const bool complex_input = framework::IsComplexType(in_dtype);
const bool complex_output = framework::IsComplexType(out_dtype);
const DFTI_CONFIG_VALUE domain = [&] {
if (forward) {
return complex_input ? DFTI_COMPLEX : DFTI_REAL;
} else {
return complex_output ? DFTI_COMPLEX : DFTI_REAL;
}
}();
// C2C, R2C, C2R
const FFTTransformType fft_type = GetFFTTransformType(in_dtype, out_dtype);
const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;

// const bool complex_input = framework::IsComplexType(in_dtype);
// const bool complex_output = framework::IsComplexType(out_dtype);
// const DFTI_CONFIG_VALUE domain = [&] {
// if (forward) {
// return complex_input ? DFTI_COMPLEX : DFTI_REAL;
// } else {
// return complex_output ? DFTI_COMPLEX : DFTI_REAL;
// }
// }();

DftiDescriptor descriptor;
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
Expand Down Expand Up @@ -442,7 +447,7 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_out_stride.data()));

// conjugate even storage
if (!complex_input || !complex_output) {
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE,
DFTI_COMPLEX_COMPLEX));
}
Expand All @@ -455,8 +460,16 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
((normalization == FFTNormMode::by_sqrt_n)
? 1.0 / std::sqrt(static_cast<double>(signal_numel))
: 1.0 / static_cast<double>(signal_numel));
const auto scale_direction =
forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE;
const auto scale_direction = [&]() {
if (fft_type == FFTTransformType::R2C ||
(fft_type == FFTTransformType::C2C && forward)) {
return DFTI_FORWARD_SCALE;
} else {
// (fft_type == FFTTransformType::C2R ||
// (fft_type == FFTTransformType::C2C && !forward))
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
}

Expand Down Expand Up @@ -541,13 +554,44 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
DftiDescriptor desc =
_plan_mkl_fft(x->type(), out->type(), input_stride, output_stride,
signal_sizes, normalization, forward);
// dump_descriptor(desc.get());
if (forward) {
dump_descriptor(desc.get());

const FFTTransformType fft_type = GetFFTTransformType(x->type(), out->type());
if (fft_type == FFTTransformType::C2R && forward) {
framework::Tensor collapsed_input_conj(collapsed_input.type());
collapsed_input_conj.mutable_data<Ti>(collapsed_input.dims(),
ctx.GetPlace());
// conjugate the input
platform::ForRange<DeviceContext> for_range(ctx, collapsed_input.numel());
math::ConjFunctor<Ti> functor(collapsed_input.data<Ti>(),
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input_conj.data<void>(),
collapsed_output.data<void>()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(collapsed_output.type());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
collapsed_output_conj.data<void>()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
math::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
collapsed_output.numel(),
collapsed_output.data<To>());
for_range(functor);
} else {
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
if (forward) {
MKL_DFTI_CHECK(DftiComputeForward(desc.get(),
collapsed_input.data<void>(),
collapsed_output.data<void>()));
} else {
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input.data<void>(),
collapsed_output.data<void>()));
}
}

// resize for the collapsed output
Expand Down Expand Up @@ -598,7 +642,7 @@ struct FFTC2RFunctor<platform::CPUDeviceContext, Ti, To> {
FFTC2CFunctor<platform::CPUDeviceContext, Ti, Ti> c2c_functor;
c2c_functor(ctx, x, &temp, c2c_dims, normalization, forward);

const std::vector<int64_t> new_axes(axes.back());
const std::vector<int64_t> new_axes{axes.back()};
exec_fft<platform::CPUDeviceContext, Ti, To>(ctx, &temp, out, new_axes,
normalization, forward);
} else {
Expand Down

0 comments on commit b3d5f13

Please sign in to comment.