diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 110928d060585a..c5feff0f22e4a9 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -907,6 +907,7 @@ func : TrilInferMeta kernel : func : tril + inplace: (x -> out) backward : tril_grad - op : tril_indices @@ -928,6 +929,7 @@ func : TriuInferMeta kernel : func : triu + inplace: (x -> out) backward : triu_grad - op : triu_indices diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 661de64990ee6b..6e603c0dd8a90e 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -665,11 +665,12 @@ - op : digamma args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : digamma + inplace: (x -> out) backward : digamma_grad - op : dirichlet @@ -1095,12 +1096,13 @@ - op : hardtanh args : (Tensor x, float t_min=0, float t_max=24) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta param : [x] kernel : func : hardtanh + inplace: (x -> out) backward : hardtanh_grad - op : heaviside @@ -1137,6 +1139,7 @@ func : UnchangedInferMeta kernel : func : i0 + inplace: (x -> out) backward : i0_grad - op : i0e @@ -1349,12 +1352,13 @@ - op : leaky_relu args : (Tensor x, float negative_slope = 0.02f) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta param : [x] kernel : func : leaky_relu + inplace: (x -> out) backward : leaky_relu_grad - op : lerp @@ -1374,6 +1378,7 @@ func : UnchangedInferMeta kernel : func : lgamma + inplace: (x -> out) backward : lgamma_grad - op : linear_interp @@ -1401,38 +1406,42 @@ - op : log args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : log + inplace: (x -> out) backward: log_grad - op : log10 args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : log10 + inplace: (x -> out) backward: log10_grad - op : log1p args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : log1p + inplace: (x -> out) backward: log1p_grad - op : log2 args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : log2 + inplace: (x -> out) backward: log2_grad - op : log_loss @@ -1505,12 +1514,13 @@ - op : logit args : (Tensor x, float eps = 1e-6f) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta param : [x] kernel : func : logit + inplace: (x -> out) backward : logit_grad - op : logsigmoid @@ -1883,6 +1893,7 @@ param: [x] kernel : func : polygamma + inplace: (x -> out) backward : polygamma_grad - op : pow @@ -2472,12 +2483,13 @@ - op : thresholded_relu args : (Tensor x, float threshold = 1.0) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta param : [x] kernel : func : thresholded_relu + inplace: (x -> out) backward : thresholded_relu_grad - op : topk @@ -2524,11 +2536,12 @@ - op : trunc args : (Tensor input) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta kernel : func : trunc + inplace: (input -> out) backward : trunc_grad - op : unbind diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index ae5af6ce9ae478..59f3e74db93bdb 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -2032,7 +2032,7 @@ struct LogFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.template cast().unaryExpr(Log()); + out.device(d) = x.template cast().unaryExpr(Log()).eval(); } }; @@ -2076,7 +2076,7 @@ struct Log2Functor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.template cast().unaryExpr(Log2()); + out.device(d) = x.template cast().unaryExpr(Log2()).eval(); } }; @@ -2121,7 +2121,7 @@ struct Log10Functor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.template cast().unaryExpr(Log10()); + out.device(d) = x.template cast().unaryExpr(Log10()).eval(); } }; @@ -2166,7 +2166,7 @@ struct Log1pFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - out.device(d) = x.template cast().unaryExpr(Log1p()); + out.device(d) = x.template cast().unaryExpr(Log1p()).eval(); } }; diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 99c1aa35fd671c..c2c8dd78de8fc4 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -110,7 +110,9 @@ from .tensor.creation import full # noqa: F401 from .tensor.creation import full_like # noqa: F401 from .tensor.creation import triu # noqa: F401 +from .tensor.creation import triu_ # noqa: F401 from .tensor.creation import tril # noqa: F401 +from .tensor.creation import tril_ # noqa: F401 from .tensor.creation import meshgrid # noqa: F401 from .tensor.creation import empty # noqa: F401 from .tensor.creation import empty_like # noqa: F401 @@ -224,14 +226,18 @@ from .tensor.math import cumprod # noqa: F401 from .tensor.math import logcumsumexp # noqa: F401 from .tensor.math import logit # noqa: F401 +from .tensor.math import logit_ # noqa: F401 from .tensor.math import exp # noqa: F401 from .tensor.math import expm1 # noqa: F401 from .tensor.math import expm1_ # noqa: F401 from .tensor.math import floor # noqa: F401 from .tensor.math import increment # noqa: F401 from .tensor.math import log # noqa: F401 +from .tensor.math import log_ # noqa: F401 +from .tensor.math import log2_ # noqa: F401 from .tensor.math import log2 # noqa: F401 from .tensor.math import log10 # noqa: F401 +from .tensor.math import log10_ # noqa: F401 from .tensor.math import multiplex # noqa: F401 from .tensor.math import pow # noqa: F401 from .tensor.math import pow_ # noqa: F401 @@ -279,6 +285,7 @@ from .tensor.math import logaddexp # noqa: F401 from .tensor.math import inverse # noqa: F401 from .tensor.math import log1p # noqa: F401 +from .tensor.math import log1p_ # noqa: F401 from .tensor.math import erf # noqa: F401 from .tensor.math import erf_ # noqa: F401 from .tensor.math import addmm # noqa: F401 @@ -294,9 +301,13 @@ from .tensor.math import broadcast_shape # noqa: F401 from .tensor.math import conj # noqa: F401 from .tensor.math import trunc # noqa: F401 +from .tensor.math import trunc_ # noqa: F401 from .tensor.math import digamma # noqa: F401 +from .tensor.math import digamma_ # noqa: F401 from .tensor.math import neg # noqa: F401 +from .tensor.math import neg_ # noqa: F401 from .tensor.math import lgamma # noqa: F401 +from .tensor.math import lgamma_ # noqa: F401 from .tensor.math import acosh # noqa: F401 from .tensor.math import acosh_ # noqa: F401 from .tensor.math import asinh # noqa: F401 @@ -317,6 +328,7 @@ from .tensor.math import outer # noqa: F401 from .tensor.math import heaviside # noqa: F401 from .tensor.math import frac # noqa: F401 +from .tensor.math import frac_ # noqa: F401 from .tensor.math import sgn # noqa: F401 from .tensor.math import take # noqa: F401 from .tensor.math import frexp # noqa: F401 @@ -326,10 +338,12 @@ from .tensor.math import vander # noqa: F401 from .tensor.math import nextafter # noqa: F401 from .tensor.math import i0 # noqa: F401 +from .tensor.math import i0_ # noqa: F401 from .tensor.math import i0e # noqa: F401 from .tensor.math import i1 # noqa: F401 from .tensor.math import i1e # noqa: F401 from .tensor.math import polygamma # noqa: F401 +from .tensor.math import polygamma_ # noqa: F401 from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -473,6 +487,7 @@ 'logaddexp', 'logcumsumexp', 'logit', + 'logit_', 'LazyGuard', 'sign', 'is_empty', @@ -561,6 +576,7 @@ 'rand', 'less_equal', 'triu', + 'triu_', 'sin', 'sin_', 'dist', @@ -582,6 +598,7 @@ 'abs', 'abs_', 'tril', + 'tril_', 'pow', 'pow_', 'zeros_like', @@ -608,7 +625,9 @@ 'broadcast_shape', 'conj', 'neg', + 'neg_', 'lgamma', + 'lgamma_', 'lerp', 'erfinv', 'inner', @@ -693,13 +712,19 @@ 'floor', 'cosh', 'log', + 'log_', 'log2', + 'log2_', 'log10', + 'log10_', 'concat', 'check_shape', 'trunc', + 'trunc_', 'frac', + 'frac_', 'digamma', + 'digamma_', 'standard_normal', 'diagonal', 'broadcast_tensors', @@ -741,8 +766,10 @@ 'unflatten', 'nextafter', 'i0', + 'i0_', 'i0e', 'i1', 'i1e', 'polygamma', + 'polygamma_', ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index ec5ee96e3cc916..d4514479ca3430 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -21,9 +21,11 @@ from .activation import gelu # noqa: F401 from .activation import hardshrink # noqa: F401 from .activation import hardtanh # noqa: F401 +from .activation import hardtanh_ # noqa: F401 from .activation import hardsigmoid # noqa: F401 from .activation import hardswish # noqa: F401 from .activation import leaky_relu # noqa: F401 +from .activation import leaky_relu_ # noqa: F401 from .activation import log_sigmoid # noqa: F401 from .activation import maxout # noqa: F401 from .activation import prelu # noqa: F401 @@ -44,6 +46,7 @@ from .activation import tanh_ # noqa: F401 from .activation import tanhshrink # noqa: F401 from .activation import thresholded_relu # noqa: F401 +from .activation import thresholded_relu_ # noqa: F401 from .activation import log_softmax # noqa: F401 from .activation import glu # noqa: F401 from .activation import gumbel_softmax # noqa: F401 @@ -153,9 +156,11 @@ 'gelu', 'hardshrink', 'hardtanh', + 'hardtanh_', 'hardsigmoid', 'hardswish', 'leaky_relu', + 'leaky_relu_', 'log_sigmoid', 'maxout', 'prelu', @@ -176,6 +181,7 @@ 'tanh_', 'tanhshrink', 'thresholded_relu', + 'thresholded_relu_', 'log_softmax', 'glu', 'gumbel_softmax', diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 9742ea25f8c6d9..26d81c7e38d1c9 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -299,6 +299,16 @@ def hardtanh(x, min=-1.0, max=1.0, name=None): return out +@inplace_apis_in_dygraph_only +def hardtanh_(x, min=-1.0, max=1.0, name=None): + r""" + Inplace version of ``hardtanh`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`paddle_nn_functional_hardtanh`. + """ + if in_dynamic_mode(): + return _C_ops.hardtanh_(x, min, max) + + def hardsigmoid(x, slope=0.1666667, offset=0.5, name=None): r""" hardsigmoid activation. Calculate the `hardsigmoid` of input `x`. @@ -458,6 +468,16 @@ def leaky_relu(x, negative_slope=0.01, name=None): return out +@inplace_apis_in_dygraph_only +def leaky_relu_(x, negative_slope=0.01, name=None): + r""" + Inplace version of ``leaky_relu`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`paddle_nn_functional_leaky_relu`. + """ + if in_dynamic_mode(): + return _C_ops.leaky_relu_(x, negative_slope) + + def prelu(x, weight, data_format="NCHW", name=None): """ prelu activation. The calculation formula is follows: @@ -1498,6 +1518,16 @@ def thresholded_relu(x, threshold=1.0, name=None): return out +@inplace_apis_in_dygraph_only +def thresholded_relu_(x, threshold=1.0, name=None): + r""" + Inplace version of ``thresholded_relu`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`paddle_nn_functional_thresholded_relu`. + """ + if in_dynamic_mode(): + return _C_ops.thresholded_relu_(x, threshold) + + def log_softmax(x, axis=-1, dtype=None, name=None): r""" This operator implements the log_softmax layer. The calculation process is diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ccd61d7bb2114b..0dc190f6ebb612 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -35,7 +35,9 @@ from .creation import full # noqa: F401 from .creation import full_like # noqa: F401 from .creation import triu # noqa: F401 +from .creation import triu_ # noqa: F401 from .creation import tril # noqa: F401 +from .creation import tril_ # noqa: F401 from .creation import meshgrid # noqa: F401 from .creation import empty # noqa: F401 from .creation import empty_like # noqa: F401 @@ -162,6 +164,7 @@ from .math import cumprod # noqa: F401 from .math import logcumsumexp # noqa: F401 from .math import logit # noqa: F401 +from .math import logit_ # noqa: F401 from .math import exp # noqa: F401 from .math import exp_ # noqa: F401 from .math import expm1 # noqa: F401 @@ -169,6 +172,7 @@ from .math import floor_ # noqa: F401 from .math import increment # noqa: F401 from .math import log # noqa: F401 +from .math import log_ # noqa: F401 from .math import multiplex # noqa: F401 from .math import pow # noqa: F401 from .math import pow_ # noqa: F401 @@ -221,8 +225,11 @@ from .math import logaddexp # noqa: F401 from .math import inverse # noqa: F401 from .math import log2 # noqa: F401 +from .math import log2_ # noqa: F401 from .math import log10 # noqa: F401 +from .math import log10_ # noqa: F401 from .math import log1p # noqa: F401 +from .math import log1p_ # noqa: F401 from .math import erf # noqa: F401 from .math import addmm # noqa: F401 from .math import addmm_ # noqa: F401 @@ -239,9 +246,13 @@ from .math import broadcast_shape # noqa: F401 from .math import conj # noqa: F401 from .math import trunc # noqa: F401 +from .math import trunc_ # noqa: F401 from .math import digamma # noqa: F401 +from .math import digamma_ # noqa: F401 from .math import neg # noqa: F401 +from .math import neg_ # noqa: F401 from .math import lgamma # noqa: F401 +from .math import lgamma_ # noqa: F401 from .math import diagonal # noqa: F401 from .math import acosh # noqa: F401 from .math import acosh_ # noqa: F401 @@ -265,6 +276,7 @@ from .math import outer # noqa: F401 from .math import heaviside # noqa: F401 from .math import frac # noqa: F401 +from .math import frac_ # noqa: F401 from .math import sgn # noqa: F401 from .math import take # noqa: F401 from .math import frexp # noqa: F401 @@ -276,10 +288,12 @@ from .math import vander # noqa: F401 from .math import nextafter # noqa: F401 from .math import i0 # noqa: F401 +from .math import i0_ # noqa: F401 from .math import i0e # noqa: F401 from .math import i1 # noqa: F401 from .math import i1e # noqa: F401 from .math import polygamma # noqa: F401 +from .math import polygamma_ # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -367,6 +381,7 @@ 'cumprod', 'logcumsumexp', 'logit', + 'logit_', 'exp', 'exp_', 'expm1', @@ -375,8 +390,11 @@ 'increment', 'logaddexp', 'log', + 'log_', 'log2', + 'log2_', 'log10', + 'log10_', 'logsumexp', 'multiplex', 'pow', @@ -432,6 +450,7 @@ 'logsumexp', 'inverse', 'log1p', + 'log1p_', 'erf', 'addmm', 'addmm_', @@ -446,19 +465,26 @@ 'broadcast_shape', 'conj', 'neg', + 'neg_', 'lgamma', + 'lgamma_', 'equal', 'equal_all', 'greater_equal', + 'greater_equal_', 'greater_than', + 'greater_than_', 'is_empty', 'less_equal', + 'less_equal_', 'less_than', + 'less_than_', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'not_equal', + 'not_equal_', 'allclose', 'isclose', 'is_tensor', @@ -525,9 +551,12 @@ 'imag', 'is_floating_point', 'digamma', + 'digamma_', 'diagonal', 'trunc', + 'trunc_', 'frac', + 'frac_', 'bitwise_and', 'bitwise_or', 'bitwise_xor', @@ -583,10 +612,12 @@ 'nextafter', 'unflatten', 'i0', + 'i0_', 'i0e', 'i1', 'i1e', 'polygamma', + 'polygamma_', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index d12017ea4b219c..96a2077337d766 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -22,6 +22,7 @@ import paddle from paddle import _C_ops +from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ..fluid.data_feeder import ( check_dtype, @@ -1462,6 +1463,17 @@ def tril(x, diagonal=0, name=None): return _tril_triu_op(LayerHelper('tril', **locals())) +@inplace_apis_in_dygraph_only +def tril_(x, diagonal=0, name=None): + r""" + Inplace version of ``tril`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_tril`. + """ + + if in_dynamic_mode(): + return _C_ops.tril_(x, diagonal) + + def triu(x, diagonal=0, name=None): r""" Return the upper triangular part of a matrix (2-D tensor) or batch of matrices @@ -1524,6 +1536,17 @@ def triu(x, diagonal=0, name=None): return _tril_triu_op(LayerHelper('triu', **locals())) +@inplace_apis_in_dygraph_only +def triu_(x, diagonal=0, name=None): + r""" + Inplace version of ``triu`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_triu`. + """ + + if in_dynamic_mode(): + return _C_ops.triu_(x, diagonal) + + def meshgrid(*args, **kwargs): """ diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 42b8d02ead727c..157349e8ff013a 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -187,6 +187,17 @@ def log(x, name=None): return out +@inplace_apis_in_dygraph_only +def log_(x, name=None): + r""" + Inplace version of ``log`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_log`. + """ + + if in_dynamic_mode(): + return _C_ops.log_(x) + + def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): """ Scale operator. @@ -1821,6 +1832,16 @@ def trunc(input, name=None): return out +@inplace_apis_in_dygraph_only +def trunc_(input, name=None): + r""" + Inplace version of ``trunc`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_trunc`. + """ + if in_dynamic_mode(): + return _C_ops.trunc_(input) + + def mm(input, mat2, name=None): """ @@ -2877,6 +2898,17 @@ def log1p(x, name=None): return out +@inplace_apis_in_dygraph_only +def log1p_(x, name=None): + r""" + Inplace version of ``log1p`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_log1p`. + """ + + if in_dynamic_mode(): + return _C_ops.log1p_(x) + + def log2(x, name=None): r""" Calculates the log to the base 2 of the given input tensor, element-wise. @@ -2932,6 +2964,17 @@ def log2(x, name=None): return out +@inplace_apis_in_dygraph_only +def log2_(x, name=None): + r""" + Inplace version of ``log2`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_log2`. + """ + + if in_dynamic_mode(): + return _C_ops.log2_(x) + + def log10(x, name=None): r""" Calculates the log to the base 10 of the given input tensor, element-wise. @@ -2987,6 +3030,17 @@ def log10(x, name=None): return out +@inplace_apis_in_dygraph_only +def log10_(x, name=None): + r""" + Inplace version of ``log10`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_log10`. + """ + + if in_dynamic_mode(): + return _C_ops.log10_(x) + + def clip(x, min=None, max=None, name=None): """ This operator clip all elements in input into the range [ min, max ] and return @@ -4385,6 +4439,16 @@ def digamma(x, name=None): return out +@inplace_apis_in_dygraph_only +def digamma_(x, name=None): + r""" + Inplace version of ``digamma`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_digamma`. + """ + if in_dynamic_mode(): + return _C_ops.digamma_(x) + + def lgamma(x, name=None): r""" Calculates the lgamma of the given input tensor, element-wise. @@ -4422,6 +4486,16 @@ def lgamma(x, name=None): return out +@inplace_apis_in_dygraph_only +def lgamma_(x, name=None): + r""" + Inplace version of ``lgamma`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_lgamma`. + """ + if in_dynamic_mode(): + return _C_ops.lgamma_(x) + + def neg(x, name=None): """ This function computes the negative of the Tensor elementwisely. @@ -4449,6 +4523,17 @@ def neg(x, name=None): ) +@inplace_apis_in_dygraph_only +def neg_(x, name=None): + r""" + Inplace version of ``neg`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_neg`. + """ + return x.scale_( + scale=-1.0, bias=0.0, bias_after_scale=True, act=None, name=name + ) + + def atan2(x, y, name=None): r""" Element-wise arctangent of x/y with consideration of the quadrant. @@ -4574,6 +4659,18 @@ def logit(x, eps=None, name=None): return out +@inplace_apis_in_dygraph_only +def logit_(x, eps=None, name=None): + r""" + Inplace version of ``logit`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_logit`. + """ + if eps is None: + eps = 0.0 + if in_dynamic_mode(): + return _C_ops.logit_(x, eps) + + def lerp(x, y, weight, name=None): r""" Does a linear interpolation between x and y based on weight. @@ -5322,6 +5419,29 @@ def frac(x, name=None): return _elementwise_op(LayerHelper('elementwise_sub', **locals())) +@inplace_apis_in_dygraph_only +def frac_(x, name=None): + r""" + Inplace version of ``frac`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_frac`. + """ + + if x.dtype not in [ + paddle.int32, + paddle.int64, + paddle.float32, + paddle.float64, + ]: + raise TypeError( + "The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {}".format( + x.dtype + ) + ) + if in_dynamic_mode(): + y = _C_ops.trunc(x) + return _C_ops.subtract_(x, y) + + def sgn(x, name=None): """ For complex tensor, this API returns a new tensor whose elements have the same angles as the corresponding @@ -5884,6 +6004,17 @@ def i0(x, name=None): return out +@inplace_apis_in_dygraph_only +def i0_(x, name=None): + r""" + Inplace version of ``i0`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_i0`. + """ + + if in_dynamic_mode(): + return _C_ops.i0_(x) + + def i0e(x, name=None): r""" The function used to calculate exponentially scaled modified Bessel function of order 0. @@ -6046,6 +6177,27 @@ def polygamma(x, n, name=None): return out +def polygamma_(x, n, name=None): + r""" + Inplace version of ``polygamma`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_polygamma`. + """ + if not isinstance(n, int): + raise TypeError( + "The input of n must be int type, but received: %s " % (type(n)) + ) + if n < 0: + raise ValueError( + "The input of n must be greater than or equal to 0. But received n = %s" + % (n) + ) + if n == 0: + return digamma_(x) + else: + if in_dynamic_mode(): + return _C_ops.polygamma_(x, n) + + def ldexp(x, y, name=None): """ Compute the result of multiplying x by 2 to the power of y. The equation is: diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index d1dbc00dd55a42..48cc9a9faa8dba 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -697,5 +697,141 @@ def test_type_error(self): paddle.pow_(var, [2]) +class TestDygraphInplaceTriu(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.triu_(var, 0) + + def non_inplace_api_processing(self, var): + return paddle.triu(var, 0) + + +class TestDygraphInplaceTril(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.tril_(var, 0) + + def non_inplace_api_processing(self, var): + return paddle.tril(var, 0) + + +class TestDygraphInplaceLogit(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.logit_(var, 1e-3) + + def non_inplace_api_processing(self, var): + return paddle.logit(var, 1e-3) + + +class TestDygraphInplaceLog(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.log_(var) + + def non_inplace_api_processing(self, var): + return paddle.log(var) + + +class TestDygraphInplaceLog2(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.log2_(var) + + def non_inplace_api_processing(self, var): + return paddle.log2(var) + + +class TestDygraphInplaceLog10(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.log10_(var) + + def non_inplace_api_processing(self, var): + return paddle.log10(var) + + +class TestDygraphInplaceLog1p(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.log1p_(var) + + def non_inplace_api_processing(self, var): + return paddle.log1p(var) + + +class TestDygraphInplaceTrunc(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.trunc_(var) + + def non_inplace_api_processing(self, var): + return paddle.trunc(var) + + +class TestDygraphInplaceDigamma(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.digamma_(var) + + def non_inplace_api_processing(self, var): + return paddle.digamma(var) + + +class TestDygraphInplaceNeg(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.neg_(var) + + def non_inplace_api_processing(self, var): + return paddle.neg(var) + + +class TestDygraphInplaceLgamma(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.lgamma_(var) + + def non_inplace_api_processing(self, var): + return paddle.lgamma(var) + + +class TestDygraphInplaceFrac(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.frac_(var) + + def non_inplace_api_processing(self, var): + return paddle.frac(var) + + +class TestDygraphInplaceI0(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.i0_(var) + + def non_inplace_api_processing(self, var): + return paddle.i0(var) + + +class TestDygraphInplacePolygamma(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.polygamma_(var, 1) + + def non_inplace_api_processing(self, var): + return paddle.polygamma(var, 1) + + +class TestDygraphInplaceHardTanh(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.nn.functional.hardtanh_(var, -1.0, 1.0) + + def non_inplace_api_processing(self, var): + return paddle.nn.functional.hardtanh(var, -1.0, 1.0) + + +class TestDygraphInplaceLeakyRelu(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.nn.functional.leaky_relu_(var, 0.01) + + def non_inplace_api_processing(self, var): + return paddle.nn.functional.leaky_relu(var, 0.01) + + +class TestDygraphInplaceThresholdedRelu(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.nn.functional.thresholded_relu_(var, 1.0) + + def non_inplace_api_processing(self, var): + return paddle.nn.functional.thresholded_relu(var, 1.0) + + if __name__ == '__main__': unittest.main()