
Fix the accuracy problem of allclose op when using float64 data type #27891

Merged

merged 8 commits into PaddlePaddle:develop on Oct 19, 2020

Conversation

huangxu96
Contributor

@huangxu96 huangxu96 commented Oct 13, 2020

PR types

Bug fixes

PR changes

OPs

Describe

This PR fixes a bug in allclose_op, which in some cases does not produce the expected output when fp64 data is used as input.

Bug reproduction:

import numpy as np
import paddle

paddle.disable_static()

np_x = np.array([10.1]).astype("float64")
np_y = np.array([10]).astype("float64")

x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)

result = paddle.allclose(x=x, y=y, rtol=0.01, atol=0, equal_nan=False, name="ignore_nan")
result = result.numpy()
print(result)

The result is expected to be True, but False is returned.

Problem reason:

Floating-point numbers cannot be compared for equality directly, because they cannot be represented exactly in a computer. For example, 0.1 may actually be stored as 0.09999 or 0.100001. So when we try to determine whether two values of "0.1" are equal, the computation may effectively compare 0.09999 with 0.100001, which is how the false result arises.
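
To make the representation issue concrete, here is a small numpy check (an illustration added for this write-up, not code from the PR); it also shows why the precision of rtol matters:

import numpy as np

a = np.float64(10.1)
b = np.float64(10.0)
print(a - b)                                       # 0.09999999999999964, not exactly 0.1
print(a - b <= np.float64(np.float32(0.01)) * b)   # False: rtol rounded through float32
print(a - b <= np.float64(0.01) * b)               # True: rtol kept in float64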

Solving approach:

Add an extremely small margin (1e-15) when determining whether two double variables are equal, as sketched below.

Change the data type of rtol and atol from float32 to float64, since the precision of rtol and atol also affects the final result.
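
A minimal Python sketch of the adjusted comparison (the actual fix lives in the C++/CUDA kernels quoted below; the 1e-15 margin and the left/right/diff names mirror the description and snippets in this PR, so treat it as an illustration rather than the exact operator code):

def is_close(a, b, rtol, atol, epsilon=1e-15):
    left = abs(a - b)
    right = atol + rtol * abs(b)
    diff = abs(left - right)
    # Close if the tolerance holds outright, or if the two sides differ
    # by no more than the tiny epsilon margin.
    return left <= right or diff <= epsilon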

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

T operator()(const framework::Tensor& tensor) const {
const T* data = tensor.data<T>();
T value;
cudaMemcpy(&value, data, sizeof(T), cudaMemcpyDeviceToHost);
Contributor

Prefer to use memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T), dev_ctx.stream());

Contributor Author

Done.

Comment on lines +37 to +38
auto* in_a = in.data<T>();
auto* in_b = other.data<T>();
Contributor

Use a better variable name instead.

Contributor Author

Done

} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T dif = (left > right ? left - right : right - left);
Contributor

Use diff instead.

Contributor Author

Done.

}
};

template struct AllcloseFunctor<platform::CPUDeviceContext, double>;
Contributor

Not needed here.

Contributor Author

Done.

} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T dif = (left > right ? left - right : right - left);
Contributor

Use diff instead.

Contributor Author

Done.

atomicAnd(reinterpret_cast<int*>(&val_), static_cast<int>(val));
__syncthreads();
if (tid == 0) {
*out_data = static_cast<bool>(val_);
Contributor

Here static_cast is not needed.

Contributor Author

Already used parallel reduction here.

};

template <typename T>
__global__ void AllcloseCUDAKernel(const T* in_a, const T* in_b,
Contributor

The performance here is too poor.

Contributor Author

Already used parallel reduction here.

Comment on lines 78 to 79
int grid = 1;
int block = in_dims;
Contributor

Think it over and try to replace this unreasonable code.

Contributor Author

Done.

@@ -22,19 +22,20 @@ class TestAllcloseOp(OpTest):
def set_args(self):
Contributor

Add the test case below.
[screenshot of the requested test case]

Contributor Author

Done.

Superjomn
Superjomn previously approved these changes Oct 14, 2020
Contributor

@Superjomn Superjomn left a comment

Considering that allclose is hardly ever used in inference, this upgrade will not take backward compatibility into account for now.

Comment on lines +134 to +140
class TestAllcloseOpFloat64(TestAllcloseOp):
def set_args(self):
self.input = np.array([10.1]).astype("float64")
self.other = np.array([10]).astype("float64")
self.rtol = np.array([0.01]).astype("float64")
self.atol = np.array([0]).astype("float64")
self.equal_nan = False
Contributor

Add the same unit test for float32.

Contributor Author

Done.
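
For reference, a float32 counterpart of the float64 case above could look like the sketch below (assumed here to simply mirror the float64 test; the exact test added in the PR may differ):

class TestAllcloseOpFloat32(TestAllcloseOp):
    def set_args(self):
        self.input = np.array([10.1]).astype("float32")
        self.other = np.array([10]).astype("float32")
        self.rtol = np.array([0.01]).astype("float64")
        self.atol = np.array([0]).astype("float64")
        self.equal_nan = False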

}
};

template struct AllcloseFunctor<platform::CUDADeviceContext, double>;
Contributor

Remove this line.

Contributor Author

Done.

Contributor

@wzzju wzzju left a comment

LGTM.

Contributor

@Superjomn Superjomn left a comment

LGTM

@wzzju wzzju merged commit d466893 into PaddlePaddle:develop Oct 19, 2020
huangxu96 added a commit to huangxu96/Paddle that referenced this pull request Oct 19, 2020
* Still has bugs.

* Fixed allclose_op bug, which cannot deal with some cases of fp64 inputs.

* improved CUDA kernel performance.

* Changed CUDA code.

* Fixed a bug in cuda kernel which cannot deal with large dimension input, and added an unittest for it.

* Add a test case for float32 input.
@huangxu96 huangxu96 changed the title Allclose op Allclose op bug fixed Oct 19, 2020
@wzzju wzzju changed the title Allclose op bug fixed Fix the accuracy problem of allclose op when using float64 data type. Oct 19, 2020
@huangxu96 huangxu96 changed the title Fix the accuracy problem of allclose op when using float64 data type. Fixed a bug of allclose op that cannot get the expected output when fp64 data as input in some cases Oct 19, 2020
@huangxu96 huangxu96 changed the title Fixed a bug of allclose op that cannot get the expected output when fp64 data as input in some cases Fix the accuracy problem of allclose op when using float64 data type Oct 19, 2020
wzzju pushed a commit that referenced this pull request Oct 19, 2020
* Fixed allclose_op bug, which cannot deal with some cases of fp64 inputs.

* improved CUDA kernel performance.

* Fixed a bug in cuda kernel which cannot deal with large dimension input, and added an unit test for it.

* Add a test case for float32 input.