[Prim][PIR] support huber_loss op forward in prim pir #64425

Merged 2 commits on May 21, 2024

Changes from 1 commit
@@ -56,6 +56,7 @@
     "squeeze",
     "stack",
     "unsqueeze",
+    "huber_loss",
 ]

 # come into effect in generated file op_decomp.cc

@@ -95,6 +96,7 @@
     "squeeze",
     "stack",
     "unsqueeze",
+    "huber_loss",
 ]
paddle/fluid/primitive/composite/composite.h (14 additions, 0 deletions)
@@ -208,6 +208,20 @@ Tensor pow_decomp(const Tensor& x, const paddle::Scalar& y) {
 }
 }

+template <typename T>
+std::tuple<Tensor, Tensor> huber_loss_decomp(const Tensor& input,
+                                             const Tensor& label,
+                                             float delta) {
+  auto delta_full = full<T>(input.shape(), delta, input.dtype());
+  auto val = label - input;
+  auto abs_val = abs<T>(val);
+  auto ans = where<T>(abs_val <= delta_full,
+                      0.5 * val * val,
+                      delta_full * (abs_val - 0.5 * delta_full));
+  return std::make_tuple(cast<T>(ans, input.dtype()),
+                         cast<T>(val, input.dtype()));
+}
+
 template <typename T>
 Tensor one_hot_decomp(const Tensor& x, const Tensor& num_classes) {
   auto num_classes_tensor =

Contributor (review comment on the return statement):

    Please adapt this for dynamic shapes, and also confirm whether the cast here is really necessary.

Contributor Author:

    Got it.
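As a sanity check on the decomposition above, the same elementwise formula can be written in a few lines of NumPy. This is only an illustrative sketch; the function name and the small harness are ours, not part of the PR:

import numpy as np

def huber_loss_decomp_ref(x, label, delta):
    # residual is the op's second output; the loss is the first
    residual = label - x
    abs_r = np.abs(residual)
    # quadratic branch inside |r| <= delta, linear branch outside
    loss = np.where(abs_r <= delta,
                    0.5 * residual * residual,
                    delta * (abs_r - 0.5 * delta))
    return loss, residual

x = np.random.rand(100, 1).astype('float32')
y = np.random.rand(100, 1).astype('float32')
loss, residual = huber_loss_decomp_ref(x, y, delta=1.0)
assert loss.shape == (100, 1) and loss.dtype == np.float32

Note that np.where already propagates the operand dtype here, which echoes the reviewer's question about whether the final cast in the C++ decomposition is necessary.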
paddle/phi/infermeta/binary.cc (2 additions, 0 deletions)
@@ -2069,6 +2069,8 @@ void HuberLossInferMeta(const MetaTensor& input,
   auto out_dims = label_dims;
   residual->set_dims(out_dims);
   out->set_dims(out_dims);
+  out->set_dtype(input.dtype());
+  residual->set_dtype(input.dtype());
   out->share_lod(input);
 }
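The two added set_dtype calls matter because, under static graph/PIR, output dtypes come from InferMeta rather than from running the kernel. A quick eager-mode check of the intended behavior (assuming, per the test wrapper in this PR, that paddle._C_ops.huber_loss returns the (Out, Residual) pair):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(4, 1).astype('float32'))
y = paddle.to_tensor(np.random.rand(4, 1).astype('float32'))
out, residual = paddle._C_ops.huber_loss(x, y, 1.0)
# both outputs should inherit the input dtype, which is what the
# set_dtype calls above make explicit at infer-meta time
assert out.dtype == x.dtype and residual.dtype == x.dtype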
test/deprecated/legacy_test/test_huber_loss_op.py (9 additions, 5 deletions)

@@ -29,16 +29,18 @@ def huber_loss_forward(val, delta):
     return delta * (abs_val - 0.5 * delta)


-def huber_loss_wraper(x, y, delta):
+def huber_loss_wrapper(x, y, delta):
     a = paddle._C_ops.huber_loss(x, y, delta)
     return a


 class TestHuberLossOp(OpTest):
     def setUp(self):
         self.op_type = 'huber_loss'
+        self.prim_op_type = "comp"
         self.python_out_sig = ["Out"]
-        self.python_api = huber_loss_wraper
+        self.python_api = huber_loss_wrapper
+        self.public_python_api = huber_loss_wrapper

         self.delta = 1.0
         self.init_dtype()
@@ -65,7 +67,7 @@ def set_shape(self):
         return (100, 1)

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim_pir=True)

     def test_check_grad_normal(self):
         self.check_grad(['X', 'Y'], 'Out')
@@ -105,8 +107,10 @@ def init_dtype(self):
 class TestHuberLossBF16Op(OpTest):
     def setUp(self):
         self.op_type = 'huber_loss'
+        self.prim_op_type = "comp"
         self.python_out_sig = ["Out"]
-        self.python_api = huber_loss_wraper
+        self.python_api = huber_loss_wrapper
+        self.public_python_api = huber_loss_wrapper

         self.delta = 1.0
         self.init_dtype()
@@ -142,7 +146,7 @@ def set_shape(self):
         return (100, 1)

     def test_check_output(self):
-        self.check_output_with_place(self.place)
+        self.check_output_with_place(self.place, check_prim_pir=True)

     def test_check_grad_normal(self):
         self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
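The first hunk above shows only the tail of huber_loss_forward, the scalar reference the tests compare against. Judging from the visible return line and the composite decomposition, it is presumably the standard piecewise Huber definition; a plausible reconstruction (not verbatim from the file):

def huber_loss_forward(val, delta):
    abs_val = abs(val)
    if abs_val <= delta:
        return 0.5 * val * val
    return delta * (abs_val - 0.5 * delta)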