Skip to content

Commit

Permalink
【PIR OpTest Fix No.13】 Fix test_partial_concat_op (#62833)
Browse files Browse the repository at this point in the history
* [PIR] fix test_partial_concat_op

* [PIR] fix test_partial_concat_op

* [PIR] fix test_partial_concat_op

* fix_infermeta

* fix conflict

* fix conflict

* fix code style
  • Loading branch information
cmcamdy authored Mar 21, 2024
1 parent 984b284 commit 70fba62
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@
'prune_gate_by_capacity',
'push_sparse_v2',
'push_sparse_v2_',
'partial_concat',
'partial_send',
'partial_recv',
'partial_allgather',
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,16 @@
func : partial_allgather
inplace : (x -> out)

- op : partial_concat
args : (Tensor[] x, int start_index = 0, int length = -1)
output : Tensor(out)
infer_meta :
func : PartialConcatInferMeta
kernel :
func : partial_concat
data_type : x
backward : partial_concat_grad

- op : partial_recv
args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0)
output : Tensor(out)
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,16 @@
composite : pad_grad(x, out_grad, paddings, pad_value, x_grad)
backward : pad_double_grad

- backward_op : partial_concat_grad
forward : partial_concat (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int start_index, int length)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : PartialConcatGradInferMeta
param : [x]
kernel :
func : partial_concat_grad

- backward_op : partial_sum_grad
forward : partial_sum (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int start_index, int length)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ const std::unordered_set<std::string> LegacyOpList = {
SoftReluGradOp::name(),
MatchMatrixTensorOp::name(),
MatchMatrixTensorGradOp::name(),
PartialConcatOp::name(),
PartialConcatGradOp::name(),
NceOp::name(),
NceGradOp::name(),
PartialSumOp::name(),
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,15 @@
outputs :
out : Out

- op : partial_concat
backward : partial_concat_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false]

- op : partial_recv
outputs :
out : Out
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,16 @@ void NanmedianGradInferMeta(const MetaTensor& x,
x_grad->set_dtype(x.dtype());
}

void PartialConcatGradInferMeta(const std::vector<const MetaTensor*>& xs,
                                std::vector<MetaTensor*> x_grads) {
  // The gradient of partial_concat w.r.t. each input has exactly the shape
  // and dtype of that forward input, so simply mirror them one by one.
  const size_t num_inputs = xs.size();
  for (size_t idx = 0; idx != num_inputs; ++idx) {
    x_grads[idx]->set_dims(xs[idx]->dims());
    x_grads[idx]->set_dtype(xs[idx]->dtype());
  }
}

void NceGradInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const MetaTensor& weight,
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ void NanmedianGradInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* x_grad);

void PartialConcatGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads);

void PartialSumGradInferMeta(const std::vector<const MetaTensor*>& xs,
std::vector<MetaTensor*> x_grads);

Expand Down
71 changes: 71 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4543,6 +4543,77 @@ void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
out->set_dtype(xs[0]->dtype());
}

void PartialConcatInferMeta(const std::vector<const MetaTensor*>& xs,
                            int start_index,
                            int length,
                            MetaTensor* out,
                            MetaConfig config) {
  // Infer the output meta of partial_concat: a column slice of width
  // `length` starting at `start_index` is taken from every rank-2 input
  // and the slices are concatenated along axis 1.
  const size_t num_inputs = xs.size();
  PADDLE_ENFORCE_GT(num_inputs,
                    0,
                    phi::errors::InvalidArgument(
                        "ShapeError: Input tensors count should > 0. But "
                        "received inputs' length is 0."));

  int64_t batch_size = -1;
  int64_t input_len = -1;

  // Only two-dimensional inputs are supported for now. Every input must
  // share the same shape so the slice bounds are valid for all of them
  // (required in particular when length == -1).
  for (size_t i = 0; i < num_inputs; i++) {
    const auto& dims = xs[i]->dims();
    PADDLE_ENFORCE_EQ(
        dims.size(),
        2,
        phi::errors::InvalidArgument("Only support two dimensions input now."));
    if (i == 0) {
      batch_size = dims[0];
      input_len = dims[1];
      continue;
    }
    PADDLE_ENFORCE_EQ(dims[0],
                      batch_size,
                      phi::errors::InvalidArgument(
                          "The batch size of all inputs must be same"));
    PADDLE_ENFORCE_EQ(dims[1],
                      input_len,
                      phi::errors::InvalidArgument(
                          "The input len of all inputs must be same"));
  }

  // A negative start_index counts from the end, Python-slice style.
  PADDLE_ENFORCE_EQ(
      start_index >= -input_len && start_index < input_len,
      true,
      phi::errors::InvalidArgument(
          "The start_index is expected to be in range of [%d, %d), but got %d",
          -input_len,
          input_len,
          start_index));
  if (start_index < 0) {
    start_index += input_len;
  }

  // An explicit positive length must not run past the end of a row.
  if (length > 0) {
    PADDLE_ENFORCE_GE(input_len,
                      start_index + length,
                      phi::errors::OutOfRange(
                          "start_index + length is larger than input length"));
  }

  // length < 0 means "slice up to the end of the row"; the output width is
  // the slice width multiplied by the number of concatenated inputs.
  int64_t slice_len = (length < 0) ? input_len - start_index : length;
  slice_len *= static_cast<int64_t>(num_inputs);
  out->set_dims(common::make_ddim({batch_size, slice_len}));
  out->set_dtype(xs[0]->dtype());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,12 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void PartialConcatInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
MetaTensor* out,
MetaConfig config = MetaConfig());

void PartialSumInferMeta(const std::vector<const MetaTensor*>& xs,
int start_index,
int length,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ test_one_hot_v2_op
test_one_hot_v2_op_static_build
test_overlap_add_op
test_pad3d_op
test_partial_concat_op
test_partial_sum_op
test_pass_quantization
test_pixel_shuffle_op
Expand Down

0 comments on commit 70fba62

Please sign in to comment.