From 6bc9f5855b57107e0a80c7453c651a40ef81e601 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=83=E5=AE=B8?= <118902573+tlxd@users.noreply.github.com> Date: Mon, 22 Jul 2024 09:46:16 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Fix=20PIR=20Unittest=20BUAA=20No.27?= =?UTF-8?q?=E3=80=91Fix=20test=5Fslice=5Fvar=20PIR=20mode=20(#66226)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tlxd2 * tlxd2.1 * tlxd2.2 --- .../legacy_test/test_slice_var.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) rename test/{deprecated => }/legacy_test/test_slice_var.py (85%) diff --git a/test/deprecated/legacy_test/test_slice_var.py b/test/legacy_test/test_slice_var.py similarity index 85% rename from test/deprecated/legacy_test/test_slice_var.py rename to test/legacy_test/test_slice_var.py index 1ed9d4dfa9e273..6e73d5bdc566e7 100644 --- a/test/deprecated/legacy_test/test_slice_var.py +++ b/test/legacy_test/test_slice_var.py @@ -15,6 +15,7 @@ import random import unittest +import paddle from paddle import base from paddle.distributed.transpiler.distribute_transpiler import slice_variable @@ -22,14 +23,15 @@ class TestSliceVar(unittest.TestCase): def check_slice_output(self, shapes, expected_sizes, min_size): var_list = [] - program = base.Program() - for shape in shapes: - var = program.global_block().create_var( - name=str(random.randint(10000, 99999)), - persistable=True, - shape=shape, - ) - var_list.append(var) + with paddle.pir_utils.OldIrGuard(): + program = base.Program() + for shape in shapes: + var = program.global_block().create_var( + name=str(random.randint(10000, 99999)), + persistable=True, + shape=shape, + ) + var_list.append(var) blocks = slice_variable(var_list, 10, min_size) all_sizes = [] for s in expected_sizes: