[0D-Tensor] CINN support unary op, fix test_activation_op #54216

Merged
@@ -77,7 +77,42 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
       "greater_than",
       "greater_equal",
       "less_than",
-      "less_equal"};
+      "less_equal",
+      "tanh",
+      "relu",
+      "gelu",
+      "sigmoid",
+      "exp",
+      "erf",
+      "rsqrt",
+      "log",
+      "log2",
+      "log10",
+      "floor",
+      "ceil",
+      "round",
+      "trunc",
+      "sin",
+      "cos",
+      "tan",
+      "sinh",
+      "cosh",
+      "asin",
+      "acos",
+      "atan",
+      "asinh",
+      "acosh",
+      "atanh",
+      "isnan",
+      "isfinite",
+      "isinf",
+      "negative",
+      "sign",
+      "abs",
+      "reciprocal",
+      "logical_not",
+      "bitwise_not"};
+
   std::unordered_set<std::string> white_tensor_name;
   // enable white_op_list only when graph_node_size = 1, which means single op
   // test
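For context, the comment in the pass above notes that white_op_list is consulted only when the graph holds a single op. A minimal Python sketch of that gating condition (the function and variable names here are hypothetical; the real check operates on the ir::Graph in the C++ pass):

# Hypothetical mirror of the gating logic in CinnZeroTensorTrickPass;
# illustrative only, not the actual implementation.
WHITE_OP_LIST = {"tanh", "relu", "gelu", "sigmoid", "exp"}  # abridged from the full list above

def keeps_zero_dim_output(op_type: str, graph_node_size: int) -> bool:
    # The white list applies only to single-op graphs
    # (graph_node_size == 1), i.e. the single-op test scenario.
    return graph_node_size == 1 and op_type in WHITE_OP_LIST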
test/legacy_test/test_activation_op.py: 50 changes (8 additions, 42 deletions)
@@ -143,9 +143,6 @@ class TestExpPrim_ZeroDim(TestExpFp32_Prim):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestExpm1(TestActivation):
     def setUp(self):
@@ -277,9 +274,6 @@ class TestSigmoid_ZeroDim(TestSigmoid):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
@@ -362,9 +356,6 @@ class TestSilu_ZeroDim(TestSilu):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestSiluAPI(unittest.TestCase):
     # test paddle.nn.Silu, paddle.nn.functional.silu
@@ -527,9 +518,6 @@ class TestTanh_ZeroDim(TestTanh):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestTanhAPI(unittest.TestCase):
     # test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh
@@ -1237,9 +1225,6 @@ class TestSqrt_ZeroDim(TestSqrt):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
@@ -1428,9 +1413,6 @@ class TestAbs_ZeroDim(TestAbs):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestCeil(TestActivation):
     def setUp(self):
@@ -1509,9 +1491,6 @@ class TestFloor_ZeroDim(TestFloor):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestCos(TestActivation):
     def setUp(self):
@@ -1521,8 +1500,7 @@ def setUp(self):
         self.prim_op_type = "prim"
         self.init_dtype()
         self.init_shape()
-        # prim not support now
-        self.enable_cinn = False
+        self.if_enable_cinn()
 
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
@@ -1539,6 +1517,9 @@ def test_check_grad(self):
             return
         self.check_grad(['X'], 'Out', check_prim=True)
 
+    def if_enable_cinn(self):
+        pass
+
 
 class TestCos_ZeroDim(TestCos):
     def init_shape(self):
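The two hunks above show the refactor applied throughout this file: the hard-coded `self.enable_cinn = False` in setUp is replaced with an overridable hook, so zero-dim subclasses no longer have to opt out of CINN. A standalone sketch of the pattern (class names simplified; the real tests derive from TestActivation):

# Sketch of the if_enable_cinn hook pattern used in this PR.
class BaseOpTest:
    def setUp(self):
        self.enable_cinn = True
        self.if_enable_cinn()  # hook: subclasses may flip the flag here

    def if_enable_cinn(self):
        pass  # default: leave CINN enabled

class ZeroDimOpTest(BaseOpTest):
    # After this PR, no if_enable_cinn override is needed here, because
    # the CINN pass now white-lists the 0-D unary ops directly.
    def init_shape(self):
        self.shape = []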
@@ -1659,8 +1640,7 @@ def setUp(self):
         self.prim_op_type = "prim"
         self.init_dtype()
         self.init_shape()
-        # prim not support now
-        self.enable_cinn = False
+        self.if_enable_cinn()
 
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
@@ -1677,6 +1657,9 @@ def test_check_grad(self):
             return
         self.check_grad(['X'], 'Out', check_prim=True)
 
+    def if_enable_cinn(self):
+        pass
+
 
 class TestSin_ZeroDim(TestSin):
     def init_shape(self):
@@ -1862,9 +1845,6 @@ class TestRelu_ZeroDim(TestRelu):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestReluAPI(unittest.TestCase):
     # test paddle.nn.ReLU, paddle.nn.functional.relu
@@ -2141,9 +2121,6 @@ class TestGelu_ZeroDim(TestGelu):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestGELUAPI(unittest.TestCase):
     # test paddle.nn.GELU, paddle.nn.functional.gelu
@@ -2396,7 +2373,6 @@ def setUp(self):
         self.outputs = {'Out': out}
         self.convert_input_output()
         self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
-        self.enable_cinn = False
 
     def init_shape(self):
         self.shape = [10, 12]
@@ -2417,10 +2393,6 @@ def test_check_output(self):
 
 
 class TestHardSwish_ZeroDim(TestHardSwish):
-    def setUp(self):
-        super().setUp()
-        self.enable_cinn = False
-
     def init_shape(self):
         self.shape = []
 
@@ -2831,9 +2803,6 @@ class TestLog_ZeroDim(TestLog):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestLog2(TestActivation):
     def setUp(self):
@@ -3131,9 +3100,6 @@ class TestPow_ZeroDim(TestPow):
     def init_shape(self):
         self.shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestPow_factor_tensor(TestActivation):
     def setUp(self):
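As a quick illustration of the 0-D behavior these tests exercise, a minimal standalone snippet (assumes a Paddle build; this runs the ops in eager mode and only checks shapes, so it does not itself go through the CINN pass changed above):

import paddle

# A 0-D tensor has shape [], not [1].
x = paddle.to_tensor(0.5)
for op in (paddle.tanh, paddle.sin, paddle.cos, paddle.exp, paddle.abs):
    y = op(x)
    assert y.shape == [], f"{op.__name__}: expected 0-D output, got {y.shape}"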