Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InferSymbolicShape for pd_op.meshgrid #62710

Merged
merged 11 commits into from
Apr 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,38 @@ bool LogspaceOpInferSymbolicShape(
return true;
}

bool MeshgridOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
const symbol::TensorListShapeOrDataDimExprs &shape_data_list =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0))
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

const symbol::ShapeOrDataDimExprs sym_shape_dim_exprs = [&] {
symbol::TensorListShapeOrDataDimExprs shape_dim_exprs_list;
std::vector<symbol::DimExpr> vec;

for (auto &shape_data : shape_data_list) {
if (shape_data.shape().size() == 0) {
vec.emplace_back(1);
} else {
vec.emplace_back(shape_data.shape()[0]);
}
}

auto shape_dim_exprs = symbol::TensorShapeOrDataDimExprs(vec);

for (size_t i = 0; i < shape_data_list.size(); i++) {
shape_dim_exprs_list.emplace_back(shape_dim_exprs);
}

return symbol::ShapeOrDataDimExprs(shape_dim_exprs_list);
}();

pir::Value res = op->result(0);
shape_analysis->SetShapeOrDataForValue(res, sym_shape_dim_exprs);
return true;
}

bool StackOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value operand_source = op->operand_source(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where_)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,7 @@
func : meshgrid
data_type : inputs
backward : meshgrid_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : mode
args : (Tensor x, int axis = -1, bool keepdim = false)
Expand Down
55 changes: 55 additions & 0 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_multinary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,61 @@ def test_eval_symbolic(self):
return out


class MeshgridNet(paddle.nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x, y):
out_x, out_y = paddle.meshgrid(x, y)
return out_x, out_y


class MeshgridOpInferSymbolicShapeTest(TestBase):
def prepare_data(self):
self.x_cases = [
np.random.rand(1),
np.random.rand(10),
np.random.rand(100),
np.random.rand(1000),
]
self.y_cases = [
np.random.rand(1),
np.random.rand(10),
np.random.rand(1000),
np.random.rand(100),
]

self.expected = [
'shape[S0, S1], data[NULL], shape[S0, S1], data[NULL]',
]

def test_eval_symbolic(self):
net = MeshgridNet()

for i in range(len(self.x_cases)):
x = self.x_cases[i]
y = self.y_cases[i]
x_spec = InputSpec(
shape=[None for index in range(len(x.shape))], dtype='float32'
)
y_spec = InputSpec(
shape=[None for index in range(len(y.shape))], dtype='float32'
)

input_spec = [x_spec, y_spec]
net = apply_to_static(net, False, input_spec)
net.eval()
check_infer_results(
net, input_spec, 'pd_op.meshgrid', self.expected
)

# TODO(WintersMontagne10335): Add builtin.meshgrid op infer symbolic shape test
# Not added because attribute `sym_shape_str` does not support multi-output op now.
# See also: paddle/fluid/pir/transforms/shape_optimization_pass.cc:144.

return True


class SliceNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
Expand Down