Add cpu kernel of new api : lstsq #38585
Thanks for your contribution!
LGTM for parallel_UT_rule.py
res_singular_values = results[3].numpy()

if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
    if (np.abs(res_residuals - self._output_residuals) < 1e-6).any():
Use np.allclose instead of np.abs and any? Generally, set rtol/atol to 1e-6/0.
The same applies to the other np.abs checks below.
Thanks, I will fix them in the next PR.
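For readers following this thread, here is a minimal sketch of the difference between the two checks. It assumes the test intends every element to match; the arrays `expected` and `actual` are made-up illustration values, not data from the PR.

```python
import numpy as np

# Hypothetical values for illustration only (not taken from the PR's tests).
expected = np.array([1.0, 2.0, 3.0])
actual = np.array([1.0, 2.5, 3.0])  # the second element is wrong

# Element-wise comparison reduced with any(): True as soon as ONE element is close.
any_close = bool((np.abs(actual - expected) < 1e-6).any())

# np.allclose requires EVERY element to satisfy
# |actual - expected| <= atol + rtol * |expected|.
all_close = bool(np.allclose(actual, expected, rtol=1e-6, atol=0))

print(any_close, all_close)  # True False
```

With `.any()` the mismatch in the second element goes unnoticed, while `np.allclose` reports it.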
fetch_list=[results])

if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
    if (np.abs(fetches[1] - self._output_residuals) < 1e-6).any():
Use np.allclose instead of np.abs and any? Generally, set rtol/atol to 1e-6/0.
The same applies to the other np.abs checks below.
Thanks, I will fix them in the next PR.
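One caveat about the suggested rtol/atol of 1e-6/0, sketched with made-up values: with atol=0 the tolerance is purely relative, so any expected value that is exactly zero must be matched exactly. The alternative atol of 1e-8 below is only an illustrative choice, not something proposed in this PR.

```python
import numpy as np

# Hypothetical values for illustration only.
expected = np.array([0.0, 1.0])
actual = np.array([1e-12, 1.0])  # tiny absolute error next to an exact zero

# With atol=0 the allowed error is rtol * |expected|, which is 0 where expected == 0.
strict = bool(np.allclose(actual, expected, rtol=1e-6, atol=0))

# A small absolute tolerance (1e-8 is an assumed, illustrative value) absorbs it.
relaxed = bool(np.allclose(actual, expected, rtol=1e-6, atol=1e-8))

print(strict, relaxed)  # False True
```

Whether this matters depends on whether the reference residuals can contain exact zeros.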
Is there no GPU implementation?
@@ -2503,3 +2502,107 @@ def __check_input(x, UPLO):
attrs={'UPLO': UPLO,
       'is_test': is_test})
return out_value


def lstsq(x, y, rcond=1e-15, driver=None, name=None):
Please add comments.
The implementation of the GPU kernel and the docs will be added in the next PR.
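Until the docs land, the following sketch shows what a least-squares solve returns, using NumPy's `np.linalg.lstsq` as a stand-in. This is not the Paddle API from this PR. The shapes and random inputs are assumptions chosen for illustration, and the four outputs are assumed to correspond to the solution, residuals, rank, and singular values that the unit tests above compare.

```python
import numpy as np

# Illustrative inputs; shapes are assumptions, not the PR's test shapes.
rng = np.random.default_rng(0)
a = rng.standard_normal((6, 3))  # overdetermined: more rows than columns
b = rng.standard_normal((6, 2))  # two right-hand sides

# NumPy reference: minimizes ||a @ solution - b|| for each column of b.
solution, residuals, rank, singular_values = np.linalg.lstsq(a, b, rcond=1e-15)

print(solution.shape)         # (3, 2)
print(residuals.shape)        # (2,)
print(rank)                   # 3
print(singular_values.shape)  # (3,)
```

In NumPy, residuals are only populated when the system is overdetermined and has full column rank, which mirrors the `x.shape[-2] > x.shape[-1]` and rank condition guarding the residual checks in the tests.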
self.rcond = 1e-15
self.driver = "gelss"
self._input_shape_1 = (50, 600)
self._input_shape_2 = (50, 300)
The UT needs to check the correctness of the computation.
The correctness of computation is checked in UT.
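As a sketch of what a correctness check for a least-squares solver can look like independently of any reference implementation, the example below verifies the normal equations and recomputes the residuals. It uses NumPy only; the shapes, seed, and tolerances are assumptions for illustration and do not reproduce the PR's unit test.

```python
import numpy as np

# Illustrative full-rank, overdetermined problem (assumed shapes).
rng = np.random.default_rng(0)
a = rng.standard_normal((8, 3))
b = rng.standard_normal((8, 2))

solution, residuals, rank, _ = np.linalg.lstsq(a, b, rcond=1e-15)

# At a least-squares minimum the residual is orthogonal to the columns of a,
# i.e. a.T @ (a @ solution - b) is zero up to rounding error.
gradient = a.T @ (a @ solution - b)
normal_eq_ok = bool(np.allclose(gradient, 0.0, atol=1e-10))

# The reported residuals should equal the squared 2-norm of each residual column.
recomputed = np.sum((a @ solution - b) ** 2, axis=0)
residuals_ok = bool(np.allclose(residuals, recomputed, rtol=1e-6, atol=0))

print(normal_eq_ok, residuals_ok)
```

Checks of this kind catch a wrong solution even when the reference and the kernel under test share the same bug.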
LGTM
PR types
New features
PR changes
APIs
Describe
Add CPU kernel of new API: lstsq