Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disable __setitem__ in static mode & add API paddle.static.setitem with dy2st strategy #53682

Merged
merged 33 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
3dfa3cc
add paddle.static.setitem
zoooo0820 Apr 20, 2023
c7730e5
add some help doc
zoooo0820 Apr 20, 2023
5df24af
support setitem
Aurelius84 Apr 23, 2023
9ada910
support machanism
Aurelius84 Apr 23, 2023
44944b2
add more unittest
Aurelius84 Apr 24, 2023
aa0247f
remove usless code
Aurelius84 Apr 24, 2023
dc99e44
merge develop
zoooo0820 May 10, 2023
788f8ac
raise error in static setitem
zoooo0820 May 10, 2023
5f9f919
add static api
zoooo0820 May 11, 2023
c4853e0
fix d2s UT
zoooo0820 May 11, 2023
4218a96
remove static only for both-used code
zoooo0820 May 16, 2023
69da71a
fix bool set_value in static, fix set_value_op UT
zoooo0820 May 16, 2023
825f058
fix unittests
zoooo0820 May 18, 2023
3d60e1a
[May case some error]: remove inplace-version check
zoooo0820 May 18, 2023
9b6d777
add two test case for dy2st
zoooo0820 May 18, 2023
bb0c13e
fix function in vision
zoooo0820 May 19, 2023
f7e1641
fix dy2st setitem support, refine UT case
zoooo0820 May 26, 2023
84a1bff
merge dev
zoooo0820 May 26, 2023
51345e1
fix slice in static_mode
NotHaozi Jul 10, 2023
d8ad56b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
NotHaozi Jul 10, 2023
dfd777f
add ParametersMap
NotHaozi Jul 10, 2023
8d8eb6a
remove pop
NotHaozi Jul 10, 2023
e8ecda1
modify place
NotHaozi Jul 11, 2023
df0debc
[fix]: variable is also a tensor
NotHaozi Jul 11, 2023
0d706e0
merge develop & solve conflict
zoooo0820 Jul 11, 2023
3013d67
rewrite some ut & remove slicetransformer in dy2st
zoooo0820 Jul 11, 2023
0caea01
merge new strategy of setitem dy2st
zoooo0820 Jul 11, 2023
c685130
solve error in static-mode
zoooo0820 Jul 11, 2023
65d2c72
fix ut
zoooo0820 Jul 12, 2023
2da88ee
Merge branch 'develop' into fix_d2s
zoooo0820 Jul 12, 2023
7b5c1c4
return a result for set_array_write
zoooo0820 Jul 13, 2023
a846b79
fix test_set_value_op_xpu
zoooo0820 Jul 13, 2023
60686fd
code is different in dynamic / static mode
zoooo0820 Jul 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/paddle/distribution/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def log_prob(self, value):
logits, value = paddle.broadcast_tensors(
[paddle.log(self.probs), value]
)
logits[(value == 0) & (paddle.isinf(logits))] = 0
logits = paddle.static.setitem(
logits, (value == 0) & (paddle.isinf(logits)), 0
)

return (
paddle.lgamma(value.sum(-1) + 1)
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,9 @@ def __getitem__(self, item):
return _getitem_impl_(self, item)

def __setitem__(self, item, value):
return _setitem_impl_(self, item, value)
raise RuntimeError(
"In static mode, the __setitem__ (looks like: x[indices] = values) should not be used. Please use x = paddle.static.setitem(x, indices, values)"
)

def get_value(self, scope=None):
"""
Expand Down
25 changes: 17 additions & 8 deletions python/paddle/fluid/tests/unittests/test_program_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_int32(self):
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.int32)
patch = np.array([41, 42]).astype(np.int32)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = np.ones([3, 4], dtype=np.int32)
x_output = x_input.copy()
Expand Down Expand Up @@ -110,10 +111,12 @@ def test_int64(self):
patch = np.array(
[np.iinfo(np.int64).max, np.iinfo(np.int64).min]
).astype(np.int64)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = np.ones([3, 4], dtype=np.int64)
x_output = x_input.copy()

x_output[:1, :2] = patch

self.fetch_list = [x.name]
Expand Down Expand Up @@ -142,7 +145,8 @@ def test_float32(self):
patch = np.array(
[np.finfo(np.float32).max, np.finfo(np.float32).min]
).astype(np.float32)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = np.ones([3, 4], dtype=np.float32)
x_output = x_input.copy()
Expand Down Expand Up @@ -171,7 +175,8 @@ def test_float64(self):
patch = np.array(
[np.finfo(np.float64).max, np.finfo(np.float64).min]
).astype(np.float64)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = np.ones([3, 4], dtype=np.float64)
x_output = x_input.copy()
Expand Down Expand Up @@ -200,7 +205,8 @@ def test_float16(self):
patch = np.array(
[np.finfo(np.float16).max, np.finfo(np.float16).min]
).astype(np.float16)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = np.ones([3, 4], dtype=np.float16)
x_output = x_input.copy()
Expand All @@ -227,7 +233,8 @@ def test_bool(self):
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.bool)
patch = np.array([True, False])
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = np.ones([3, 4], dtype=bool)
x_output = x_input.copy()
Expand Down Expand Up @@ -257,7 +264,8 @@ def test_complex64(self):
paddle.ones([3, 4], dtype=paddle.float32),
)
patch = np.array([42.1 + 42.1j, 42.2 + 42.2j]).astype(np.complex64)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex64)
x_output = x_input.copy()
Expand All @@ -282,7 +290,8 @@ def test_complex128(self):
np.finfo(np.float64).min + 1j * np.finfo(np.float64).max,
]
).astype(np.complex128)
x[:1, :2] = patch
index = (slice(None, 1), slice(None, 2))
x = paddle.static.setitem(x, index, patch)

x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex128)
x_output = x_input.copy()
Expand Down
Loading