
Commit 27127d6 (parent: a08ad07)

[Prim][PIR] support huber_loss op forward in prim pir (#64425)

* update huber_loss
* update dynamic
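
For reference, huber_loss takes input and label, forms the residual r = label - input, and computes the elementwise loss

$$
\mathrm{huber}(r;\delta) =
\begin{cases}
\tfrac{1}{2} r^{2}, & |r| \le \delta,\\[2pt]
\delta\left(|r| - \tfrac{1}{2}\delta\right), & |r| > \delta,
\end{cases}
$$

returning the residual r itself as a second output. The decomposition added below expresses exactly this with full/abs/where primitives.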

File tree: 5 files changed (+53 -5 lines)

paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py (+2)

@@ -56,6 +56,7 @@
     "squeeze",
     "stack",
     "unsqueeze",
+    "huber_loss",
 ]
 
 # come into effect in generated file op_decomp.cc
@@ -95,6 +96,7 @@
     "squeeze",
     "stack",
     "unsqueeze",
+    "huber_loss",
 ]
 
 # xshape output will no longer used after decomp, but return none to keep output num the same as origin op

paddle/fluid/primitive/composite/composite.h (+19)

@@ -208,6 +208,25 @@ Tensor pow_decomp(const Tensor& x, const paddle::Scalar& y) {
   }
 }
 
+template <typename T>
+std::tuple<Tensor, Tensor> huber_loss_decomp(const Tensor& input,
+                                             const Tensor& label,
+                                             float delta) {
+  Tensor delta_full;
+  if (has_dynamic_shape(input.shape())) {
+    delta_full =
+        backend::full_with_tensor<T>(shape<T>(input), delta, input.dtype());
+  } else {
+    delta_full = full<T>(input.shape(), delta, input.dtype());
+  }
+  auto val = label - input;
+  auto abs_val = abs<T>(val);
+  auto ans = where<T>(abs_val <= delta_full,
+                      0.5 * val * val,
+                      delta_full * (abs_val - 0.5 * delta_full));
+  return std::make_tuple(ans, val);
+}
+
 template <typename T>
 Tensor one_hot_decomp(const Tensor& x, const Tensor& num_classes) {
   auto num_classes_tensor =
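
The decomposition is a direct translation of the Huber formula onto existing primitives (full or full_with_tensor, abs, where), with the full_with_tensor branch covering inputs whose shape is only known at runtime. For intuition, a rough dynamic-mode Paddle sketch of the static-shape branch (illustrative only; huber_loss_decomp_ref is not part of the commit):

```python
import paddle

def huber_loss_decomp_ref(input, label, delta):
    # Static-shape branch: materialize delta as a tensor of input's shape.
    delta_full = paddle.full(input.shape, delta, dtype=input.dtype)
    val = label - input  # also returned as the second output ("Residual")
    abs_val = paddle.abs(val)
    out = paddle.where(abs_val <= delta_full,
                       0.5 * val * val,
                       delta_full * (abs_val - 0.5 * delta_full))
    return out, val

out, residual = huber_loss_decomp_ref(
    paddle.rand([100, 1]), paddle.rand([100, 1]), delta=1.0)
```
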

paddle/phi/infermeta/binary.cc (+2)

@@ -2069,6 +2069,8 @@ void HuberLossInferMeta(const MetaTensor& input,
   auto out_dims = label_dims;
   residual->set_dims(out_dims);
   out->set_dims(out_dims);
+  out->set_dtype(input.dtype());
+  residual->set_dtype(input.dtype());
   out->share_lod(input);
 }
 
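
The two set_dtype calls fill in output metadata that previously only carried dims and LoD: the decomposed program needs Out and Residual to inherit the input's dtype. A quick sanity check of the intended behavior (a sketch, assuming a Paddle build containing this commit; the tuple handling is defensive since the op exposes both Out and Residual):

```python
import paddle

x = paddle.rand([100, 1], dtype="float32")
y = paddle.rand([100, 1], dtype="float32")
ret = paddle._C_ops.huber_loss(x, y, 1.0)
outs = ret if isinstance(ret, (list, tuple)) else [ret]
# Both outputs should carry the input's dtype.
assert all(t.dtype == x.dtype for t in outs)
```
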

test/deprecated/legacy_test/test_huber_loss_op.py (+9 -5)

@@ -29,16 +29,18 @@ def huber_loss_forward(val, delta):
     return delta * (abs_val - 0.5 * delta)
 
 
-def huber_loss_wraper(x, y, delta):
+def huber_loss_wrapper(x, y, delta):
     a = paddle._C_ops.huber_loss(x, y, delta)
     return a
 
 
 class TestHuberLossOp(OpTest):
     def setUp(self):
         self.op_type = 'huber_loss'
+        self.prim_op_type = "comp"
         self.python_out_sig = ["Out"]
-        self.python_api = huber_loss_wraper
+        self.python_api = huber_loss_wrapper
+        self.public_python_api = huber_loss_wrapper
 
         self.delta = 1.0
         self.init_dtype()
@@ -65,7 +67,7 @@ def set_shape(self):
         return (100, 1)
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim_pir=True)
 
     def test_check_grad_normal(self):
         self.check_grad(['X', 'Y'], 'Out')
@@ -105,8 +107,10 @@ def init_dtype(self):
 class TestHuberLossBF16Op(OpTest):
     def setUp(self):
         self.op_type = 'huber_loss'
+        self.prim_op_type = "comp"
         self.python_out_sig = ["Out"]
-        self.python_api = huber_loss_wraper
+        self.python_api = huber_loss_wrapper
+        self.public_python_api = huber_loss_wrapper
 
         self.delta = 1.0
         self.init_dtype()
@@ -142,7 +146,7 @@ def set_shape(self):
         return (100, 1)
 
     def test_check_output(self):
-        self.check_output_with_place(self.place)
+        self.check_output_with_place(self.place, check_prim_pir=True)
 
     def test_check_grad_normal(self):
         self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
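
Setting prim_op_type = "comp" and passing check_prim_pir=True makes OpTest additionally run the op decomposed into primitives under PIR and compare that result against public_python_api. The elementwise property being asserted is, in effect (a NumPy illustration of the test's huber_loss_forward reference, not OpTest internals):

```python
import numpy as np

x = np.random.uniform(0, 1.0, (100, 1)).astype("float32")
y = np.random.uniform(0, 1.0, (100, 1)).astype("float32")
delta = 1.0

residual = y - x
expected = np.where(np.abs(residual) <= delta,
                    0.5 * residual * residual,
                    delta * (np.abs(residual) - 0.5 * delta))
# The decomposed prim program must reproduce `expected` within tolerance.
```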

test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py (+21)

@@ -74,6 +74,10 @@ def index_sample_net(x, index):
     return paddle.index_sample(x, index)
 
 
+def huber_loss_net(x, label):
+    return paddle._C_ops.huber_loss(x, label, 1.0)
+
+
 def bce_loss_net(x, label):
     return paddle._C_ops.bce_loss(x, label)
 
@@ -322,6 +326,23 @@ def setUp(self):
         self.tol = 1e-6
 
 
+class TestPrimHuberLoss(TestPrimTwo):
+    def setUp(self):
+        np.random.seed(2023)
+        self.x_shape = [100, 1]
+        self.y_shape = [100, 1]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.init_x_shape = [None, None]
+        self.init_y_shape = [None, None]
+        self.x = np.random.uniform(0, 1.0, self.x_shape).astype(self.dtype_x)
+        self.y = np.random.uniform(0, 1.0, self.y_shape).astype(self.dtype_y)
+        self.net = huber_loss_net
+        self.necessary_ops = "pd_op.huber_loss"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 class TestPrimBceLoss(TestPrimTwo):
     def setUp(self):
         np.random.seed(2023)
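
The init_x_shape/init_y_shape values of [None, None] mark both dimensions as dynamic when the net is traced to a static program, which forces huber_loss_decomp into its full_with_tensor branch (the shape tensor is computed at runtime). A minimal sketch of such a setup (assuming paddle.jit.to_static with InputSpec, which the TestPrimTwo harness presumably wraps):

```python
import paddle
from paddle.static import InputSpec

def huber_loss_net(x, label):
    return paddle._C_ops.huber_loss(x, label, 1.0)

# None dims => dynamic shapes: the decomposition cannot use a
# compile-time full() and must build the shape tensor at runtime.
specs = [InputSpec(shape=[None, None], dtype="float32"),
         InputSpec(shape=[None, None], dtype="float32")]
static_net = paddle.jit.to_static(huber_loss_net, input_spec=specs)
out = static_net(paddle.rand([100, 1]), paddle.rand([100, 1]))
```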
