Skip to content

Commit

Permalink
【Fix PIR Unittest BUAA No.27】Fix test_slice_var PIR mode (PaddlePaddl…
Browse files Browse the repository at this point in the history
…e#66226)

* tlxd2

* tlxd2.1

* tlxd2.2
  • Loading branch information
tlxd authored and lixcli committed Jul 22, 2024
1 parent 4fb3f10 commit 6bc9f58
Showing 1 changed file with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@
import random
import unittest

import paddle
from paddle import base
from paddle.distributed.transpiler.distribute_transpiler import slice_variable


class TestSliceVar(unittest.TestCase):
def check_slice_output(self, shapes, expected_sizes, min_size):
var_list = []
program = base.Program()
for shape in shapes:
var = program.global_block().create_var(
name=str(random.randint(10000, 99999)),
persistable=True,
shape=shape,
)
var_list.append(var)
with paddle.pir_utils.OldIrGuard():
program = base.Program()
for shape in shapes:
var = program.global_block().create_var(
name=str(random.randint(10000, 99999)),
persistable=True,
shape=shape,
)
var_list.append(var)
blocks = slice_variable(var_list, 10, min_size)
all_sizes = []
for s in expected_sizes:
Expand Down

0 comments on commit 6bc9f58

Please sign in to comment.