From 3f49bb6f3c7d35774249c91a720ff0aa9e35035c Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 20 May 2022 04:31:24 +0000 Subject: [PATCH 1/2] Add the ability to trace TensorList in-place ops Summary: To trace through c10d::all_gather, AOT needs to support TensorList in-place ops. Test Plan: WIP. --- functorch/_src/python_key.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/functorch/_src/python_key.py b/functorch/_src/python_key.py index 68527650d..2a3403d05 100644 --- a/functorch/_src/python_key.py +++ b/functorch/_src/python_key.py @@ -89,8 +89,15 @@ def unwrap_proxy(e): # Kind of a hacky way to test if an op is in-place or not if func.__name__[-1] == "_" and func.__name__[0] != "_": - args[0].proxy = proxy_out - proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0]) + if isinstance(args[0], torch.Tensor): + args[0].proxy = proxy_out + proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0]) + elif isinstance(args[0], list): + for i in range(len(args[0])): + args[0][i].proxy = proxy_out[i] + proxy_out[i].node.meta['tensor_meta'] = _extract_tensor_metadata(args[0][i]) + else: + assert f"unknown types: {args[0].type()}" with no_dispatch(): real_out = func_overload(*args, **kwargs) From 22ca1d9aae2f228b8d015fdc75f4df0565326edd Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Tue, 24 May 2022 17:47:28 +0000 Subject: [PATCH 2/2] Some hacks to make it work --- functorch/_src/partitioners.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py index 5fac101a5..690ba636b 100644 --- a/functorch/_src/partitioners.py +++ b/functorch/_src/partitioners.py @@ -147,10 +147,15 @@ def default_partition( continue # Since we can't save tuple of tensor values, we need to flatten out what we're saving if 'tensor_meta' not in node.meta and node.op == 'call_function': + # print("node:", node) users = node.users - assert all([user.target == operator.getitem for user in users]) + # print("operator:", operator.getitem) + # for user in users: + # print("user:",user) + # assert all([user.target == operator.getitem for user in users]) for user in users: - saved_values.append(user) + if user.target == operator.getitem: + saved_values.append(user) else: saved_values.append(node) saved_values = list(set(saved_values))