Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] optimize print function convertor to display Tensor at compile time #48672

Merged
merged 13 commits into from
Dec 5, 2022
170 changes: 54 additions & 116 deletions python/paddle/fluid/tests/unittests/dygraph_to_static/test_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,149 +16,82 @@

import numpy

import paddle
import paddle.fluid as fluid
from paddle.jit import ProgramTranslator
from paddle.jit.api import declarative
from paddle.jit import ProgramTranslator, to_static

program_translator = ProgramTranslator()


# 1. print VarBase
@declarative
# 1. print Tensor
@to_static
def dyfunc_print_variable(x):
"""
PY2:
Print(dest=None, values=[Name(id='x_v', annotation=None, type_comment=None)], nl=True)],
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[]))
"""
# NOTE: transform to static code, var name will be changed
x_v = fluid.dygraph.to_variable(x)
print(x_v)
x_t = paddle.to_tensor(x)
print(x_t)


# 2. print ndarray
@declarative
@to_static
def dyfunc_print_ndarray(x):
"""
PY2:
Print(dest=None, values=[Name(id='x', annotation=None, type_comment=None)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x', annotation=None, type_comment=None)],
keywords=[]))
"""
print(x)


# 3. print VarBase with format
@declarative
# 3. print Tensor with format
@to_static
def dyfunc_print_with_format(x):
"""
PY2:
Print(dest=None,
values=[
Call(
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
Call(
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
args=[Name(id='x_v', annotation=None, type_comment=None)],
keywords=[])],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: {}".format(x_v))


# 4. print VarBase with format 2
@declarative
x_t = paddle.to_tensor(x)
print("PrintTensor: {}".format(x_t))


# 4. print Tensor with format 2
@to_static
def dyfunc_print_with_format2(x):
"""
PY2:
Print(dest=None,
values=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
op=Mod,
right=Name(id='x_v', annotation=None, type_comment=None))],
nl=True)
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
op=Mod,
right=Name(id='x_v', annotation=None, type_comment=None))],
keywords=[]))
"""
x_v = fluid.dygraph.to_variable(x)
print("PrintVariable: %s" % (x_v))


# 5. print VarBase in control flow1
@declarative
x_t = paddle.to_tensor(x)
print("PrintTensor: %s" % (x_t))


# 5. print Tensor in control flow1
@to_static
def dyfunc_print_with_ifelse(x):
x_v = fluid.dygraph.to_variable(x)
if len(x_v.shape) > 1:
print(x_v)
x_t = paddle.to_tensor(x)
if len(x_t.shape) > 1:
print(x_t)
else:
print(x_v)
print(x_t)


# 6. print mutiple VarBases
@declarative
def dyfunc_print_multi_vars(x):
"""
# NOTE: y_v type is error before cur PR in this case
Assign(targets=[Name(id='y_v', annotation=None, type_comment=None)],
value=BinOp(left=Name(id='x_v', annotation=None, type_comment=None), op=Mult, right=Constant(value=2, kind=None)))
"""
x_v = fluid.dygraph.to_variable(x)
y_v = x_v * 2
print(x_v)
print(y_v)
# 6. print multiple Tensor
@to_static
def dyfunc_print_multi_tensor(x):
x_t = paddle.to_tensor(x)
y_t = x_t * 2
print(x_t)
print(y_t)


# 7. print continue VarBase
@declarative
# 7. print continue Tensor
@to_static
def dyfunc_print_continue_vars(x):
"""
PY3:
Expr(
value=Call(func=Name(id='print', annotation=None, type_comment=None),
args=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)],
keywords=[]))
PY2:
Print(dest=None,
values=[
Tuple(
elts=[Name(id='x_v', annotation=None, type_comment=None),
Name(id='y_v', annotation=None, type_comment=None)])],
nl=True)
"""
x_v = fluid.dygraph.to_variable(x)
y_v = x_v * 2
print(x_v, y_v)
x_t = paddle.to_tensor(x)
y_t = x_t * 2
print(x_t, y_t)


# 8. print with kwargs
@to_static
def dyfunc_print_with_kwargs(x):
x_t = paddle.to_tensor(x)
print("Tensor", x_t, end='\n\n', sep=': ')


class TestPrintBase(unittest.TestCase):
def setUp(self):
self.input = numpy.ones(5).astype("int32")
self.place = (
fluid.CUDAPlace(0)
if fluid.is_compiled_with_cuda()
else fluid.CPUPlace()
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.set_test_func()

Expand Down Expand Up @@ -207,15 +140,20 @@ def set_test_func(self):
self.dygraph_func = dyfunc_print_with_ifelse


class TestPrintMultipleVar(TestPrintVariable):
class TestPrintMultipleTensor(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_multi_vars
self.dygraph_func = dyfunc_print_multi_tensor


class TestPrintContinueVar(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_continue_vars


class TestPrintWithKwargs(TestPrintVariable):
def set_test_func(self):
self.dygraph_func = dyfunc_print_with_kwargs


if __name__ == '__main__':
unittest.main()
1 change: 0 additions & 1 deletion python/paddle/jit/dy2static/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from .convert_operators import convert_logical_not as Not # noqa: F401
from .convert_operators import convert_logical_or as Or # noqa: F401
from .convert_operators import convert_pop as Pop # noqa: F401
from .convert_operators import convert_print as Print # noqa: F401
from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import unpack_by_structure as Unpack # noqa: F401
Expand Down
4 changes: 0 additions & 4 deletions python/paddle/jit/dy2static/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@
from .loop_transformer import (
LoopTransformer,
)
from .print_transformer import (
PrintTransformer,
)
from .return_transformer import (
ReturnTransformer,
)
Expand Down Expand Up @@ -135,7 +132,6 @@ def transfer_from_node_type(self, node_wrapper):
LoopTransformer, # for/while -> while_op
IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement
PrintTransformer, # print statement
CallTransformer, # transform call recursively
CastTransformer, # type casting statement
DecoratorTransformer, # transform decorators to function call
Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/dy2static/call_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _no_need_convert_call(self, node):
'zip',
'range',
'enumerate',
'print',
}
is_builtin = eval("is_builtin({})".format(func_str))
need_convert = func_str in need_convert_builtin_func_list
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/jit/dy2static/convert_call_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
convert_zip,
convert_range,
convert_enumerate,
convert_print,
)

from paddle.jit.dy2static.logging_utils import (
Expand Down Expand Up @@ -215,6 +216,9 @@ def dyfunc(x):
if is_builtin(func, "enumerate"):
return convert_enumerate

if is_builtin(func, "print"):
return convert_print

if is_builtin(func) or is_unsupported(func):
return func

Expand Down
16 changes: 7 additions & 9 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,17 +738,15 @@ def convert_assert(cond, message=""):
assert cond, message


def convert_print(*args):
def convert_print(*objects, sep=' ', end='\n', file=None, flush=False):
"""
A function representing Python ``print`` statement. Note: this is a basic
python function so we haven't handle sep, end, file and flush parameters of
python function.
A function representing Python ``print`` function. It will print all arguments
at compile time and only print the Tensor values at runtime.
"""
for var in args:
if isinstance(var, Variable):
var = Print(var)
else:
print(var)
for obj in objects:
if isinstance(obj, Variable):
Print(obj)
print(*objects, sep=sep, end=end, file=file, flush=flush)


def convert_pop(target, *args):
Expand Down
59 changes: 0 additions & 59 deletions python/paddle/jit/dy2static/print_transformer.py

This file was deleted.