[Zero-Dim] Support paddle.max output 0D, test=allcase #53242

Merged 1 commit on Apr 24, 2023
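This change brings `paddle.max` in line with the 0-D tensor convention: reducing over all axes, or over the only axis of a 1-D input, now yields a tensor with shape `[]` instead of `[1]`. A minimal dygraph sketch of the intended behavior, mirroring the test cases added below:

```python
import paddle

x = paddle.rand([3, 5])
x.stop_gradient = False

out = paddle.max(x)   # reduce over all axes
out.backward()

print(out.shape)      # [] -> 0-D tensor, no longer [1]
print(x.grad.shape)   # [3, 5]

y = paddle.rand([5])
print(paddle.max(y, axis=0).shape)  # [] -> reducing the only axis is 0-D too
```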
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_max_op.cc
@@ -54,9 +54,10 @@ class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::OriginReduceInferMetaBase));
DECLARE_INFER_SHAPE_FUNCTOR(
reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));

REGISTER_OPERATOR(
reduce_max,
@@ -1335,7 +1335,7 @@ void max_grad(const Tensor& x,
} else {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
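The loop above previously started at `i = 1`, so in the `reduce_all` case the composite `max_grad` rebuilt its axis list without dimension 0; starting at `i = 0` covers every dimension of `x`. A small sketch of the invariant this protects, regardless of whether the composite or the plain kernel computes the gradient:

```python
import paddle

x = paddle.rand([2, 3, 4])
x.stop_gradient = False

out = paddle.max(x)   # reduce_all path: the axis list must cover dims 0..x.ndim-1
out.backward()

assert x.grad.shape == [2, 3, 4]          # gradient keeps the full input shape
assert float(paddle.sum(x.grad)) == 1.0   # only the max element gets gradient (unique max)
```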
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_ops.yaml
@@ -752,7 +752,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : OriginReduceInferMeta
func : ReduceIntArrayAxisInferMeta
kernel :
func : max
backward : max_grad
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -89,6 +89,7 @@ PD_REGISTER_KERNEL(add_n,
double,
int,
phi::dtype::bfloat16,
phi::dtype::float16,
int64_t) {}

PD_REGISTER_KERNEL(add_n_array,
@@ -99,4 +100,5 @@ PD_REGISTER_KERNEL(add_n_array,
double,
int,
phi::dtype::bfloat16,
phi::dtype::float16,
int64_t) {}
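The two registrations above add a CPU float16 kernel for `add_n`, which gradient accumulation relies on. A usage sketch, assuming a CPU build where float16 kernels are available:

```python
import paddle

paddle.set_device('cpu')
x = paddle.ones([3], dtype='float16')
y = paddle.full([3], 2.0, dtype='float16')

out = paddle.add_n([x, y])   # element-wise sum of the list of tensors
print(out.dtype)             # paddle.float16
print(out.numpy())           # [3. 3. 3.]
```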
1 change: 1 addition & 0 deletions paddle/phi/kernels/funcs/selected_rows_functor.cc
@@ -395,6 +395,7 @@ template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::float16>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;

#ifdef PADDLE_WITH_XPU
18 changes: 9 additions & 9 deletions paddle/phi/kernels/funcs/unsqueeze.h
@@ -105,32 +105,32 @@ inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,

inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
const DDim& in_dims) {
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
int output_rank = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_rank = in_dims.size();
std::vector<int64_t> output_shape(output_rank, 0);

// Validity Check: rank range.
PADDLE_ENFORCE_LE(
output_size,
output_rank,
6,
phi::errors::InvalidArgument("The output "
"tensor's rank should be less than 6."));

for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
int cur = axis < 0 ? axis + cur_output_rank + 1 : axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE_GE(
cur,
0,
phi::errors::InvalidArgument("The insert dimension value should "
"not be less than 0"));
PADDLE_ENFORCE_LE(cur,
cur_output_size,
cur_output_rank,
phi::errors::InvalidArgument(
"The insert dimension value shoule not be larger "
"than the dimension size of input tensor"));
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
for (int i = cur_output_rank; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
@@ -139,11 +139,11 @@ inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
cur_output_rank++;
}

// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
for (int in_idx = 0, out_idx = 0; out_idx < output_rank; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
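The `GetUnsqueezeShape` edits are a rename only: `output_size`/`cur_output_size` become `output_rank`/`cur_output_rank`, making it explicit that the variables track rank rather than element count. The shape rule the function implements is the usual unsqueeze behavior, for example:

```python
import paddle

print(paddle.unsqueeze(paddle.rand([3, 4]), axis=[0, 2]).shape)  # [1, 3, 1, 4]
print(paddle.unsqueeze(paddle.rand([3, 4]), axis=-1).shape)      # [3, 4, 1]

# A 0-D tensor gains its first real dimension:
print(paddle.unsqueeze(paddle.rand([]), axis=0).shape)           # [1]
```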
6 changes: 4 additions & 2 deletions paddle/phi/kernels/onednn/reduce_kernel_impl.h
@@ -102,8 +102,10 @@ void ReduceKernel(const Context& dev_ctx,
reduction_p->execute(astream, reduction_args);
astream.wait();

out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
const auto reshape_dims = out->dims().size() != 0
? vectorize<int64_t>(out->dims())
: std::vector<int64_t>{1};
out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
}
}

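With 0-D outputs, `out->dims()` can be empty, and the oneDNN memory descriptor still needs at least one dimension, so the code above falls back to `{1}` for rank-0 results. The same decision expressed in Python, as a sketch:

```python
def mem_desc_reshape_dims(out_shape):
    # A 0-D output has an empty shape; keep a single element dim for oneDNN.
    return list(out_shape) if len(out_shape) != 0 else [1]

assert mem_desc_reshape_dims([]) == [1]
assert mem_desc_reshape_dims([3, 5]) == [3, 5]
```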
@@ -242,7 +242,7 @@ def unscale_method(self, optimizer):
paddle.distributed.all_reduce(
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
)
self._found_inf = is_found_inf.numpy()[0]
self._found_inf = int(is_found_inf)


class MixPrecisionScaler:
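Because reductions such as `paddle.max` now return 0-D tensors, `tensor.numpy()` yields a 0-D `numpy.ndarray` and the old `tensor.numpy()[0]` indexing raises an `IndexError`. Converting through `int(...)`, `float(...)`, or `.item()` works for both 0-D and shape-`[1]` values, which is the pattern applied here and in the files below. A short sketch:

```python
import paddle

is_found_inf = paddle.max(paddle.to_tensor([0.0, 1.0]))  # 0-D after this change

# is_found_inf.numpy()[0]      # IndexError: the array is 0-dimensional
print(int(is_found_inf))       # 1
print(is_found_inf.item())     # 1.0
print(float(is_found_inf))     # 1.0
```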
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/math_op_patch.py
@@ -179,7 +179,7 @@ def _ndim_(var):

@property
def _size_(var):
return np.prod(var.shape)
return int(np.prod(var.shape))

@property
def _T_(var):
@@ -212,7 +212,7 @@ def test_LR_state_dict(self):
adam_test.set_dict(opt_state)
self.assertEqual(
adam_test._learning_rate.best_loss,
adam3._learning_rate.best_loss.numpy()[0],
adam3._learning_rate.best_loss,
"best_loss is different before and after set_dict",
)
self.assertEqual(
@@ -275,7 +275,7 @@ def test_LinearLrWarmup(self):
t = lr()

np.testing.assert_allclose(
t.numpy()[0].item(), right_result[i], rtol=1e-05
t.numpy().item(), right_result[i], rtol=1e-05
)

with self.assertRaises(TypeError):
@@ -342,7 +342,7 @@ def test_StepDecay(self):
right_result = step_decay(
epoch, learning_rate, step_size, decay_rate
)
fluid_result = scheduler().numpy()[0]
fluid_result = scheduler().numpy().item()
scheduler.epoch()
self.assertAlmostEqual(
right_result,
@@ -371,7 +371,7 @@ def test_LambdaDecay(self):

for epoch in range(30):
right_result = lambda_decay(epoch, learning_rate, lr_lambda)
fluid_result = scheduler().numpy()[0]
fluid_result = scheduler().numpy().item()
scheduler.epoch()
self.assertAlmostEqual(
right_result,
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_lr_scheduler.py
@@ -208,7 +208,7 @@ def _test_dygraph(self, place, kwargs):
self.assertEqual(
scheduler.cooldown_counter, scheduler1.cooldown_counter
)
self.assertEqual(scheduler.best.numpy()[0], scheduler1.best)
self.assertEqual(scheduler.best, scheduler1.best)
self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs)
self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
49 changes: 44 additions & 5 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -219,17 +219,19 @@ def test_dygraph_reduce(self):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))

# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return

x = paddle.rand([3, 5])
# 2) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool')
else:
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
out.retain_grads()
@@ -240,6 +242,21 @@ def test_dygraph_reduce(self):
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [3, 5])

# 3) x is 1D, axis=0, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, 0)
out.retain_grads()
out.backward()

self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [5])

paddle.enable_static()

def test_static_reduce(self):
@@ -284,16 +301,19 @@ def test_static_reduce(self):
np.testing.assert_allclose(res[2], np.array(1.0))
np.testing.assert_allclose(res[3], np.array(1.0))

# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return

# 2) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool')
else:
x = paddle.rand([3, 5])
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
@@ -309,6 +329,25 @@
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (3, 5))

# 3) x is 1D, axis=0, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, 0)
paddle.static.append_backward(out)

fetch_list = [out]
if block.has_var(x.grad_name):
fetch_list.extend([out.grad_name, x.grad_name])

res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
if len(res) > 1:
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (5,))

paddle.disable_static()


9 changes: 7 additions & 2 deletions python/paddle/hapi/progressbar.py
@@ -81,8 +81,13 @@ def convert_uint16_to_float(in_list):

for i, (k, val) in enumerate(values):
if k == "loss":
val = val if isinstance(val, (list, np.ndarray)) else [val]
if isinstance(val[0], np.uint16):
if isinstance(val, list):
scalar_val = val[0]
elif isinstance(val, np.ndarray):
scalar_val = val.item()
else:
scalar_val = val
if isinstance(scalar_val, np.uint16):
values[i] = ("loss", list(convert_uint16_to_float(val)))

if current_num:
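The progress bar used to wrap every loss value in a list and peek at `val[0]`, which breaks when `val` is a 0-D ndarray. The new branch picks a representative scalar from a list, a (possibly 0-D) ndarray, or a plain number before the uint16 check. The same normalization as a standalone sketch, with a hypothetical helper name:

```python
import numpy as np

def representative_scalar(val):
    # Mirrors the branch above: list -> first element, ndarray -> .item(), else unchanged.
    if isinstance(val, list):
        return val[0]
    if isinstance(val, np.ndarray):
        return val.item()
    return val

assert representative_scalar([np.float32(0.5)]) == np.float32(0.5)
assert representative_scalar(np.array(0.5)) == 0.5   # 0-D ndarray
assert representative_scalar(0.5) == 0.5
```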
2 changes: 1 addition & 1 deletion python/paddle/nn/clip.py
@@ -700,7 +700,7 @@ def _dygraph_clip(self, params_grads):
global_norm_var = paddle.add_n(global_norm_var)
global_norm_var = paddle.sqrt(global_norm_var)
max_global_norm = paddle.full(
shape=[1], dtype=global_norm_var.dtype, fill_value=self.clip_norm
shape=[], dtype=global_norm_var.dtype, fill_value=self.clip_norm
)

need_clip = False
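`global_norm_var` comes out of `sqrt(add_n(...))` over per-tensor squared norms, which are now 0-D, so the clip threshold is created with `shape=[]` to keep both operands 0-D instead of broadcasting to shape `[1]`. A sketch of the shape relationship, loosely following `ClipGradByGlobalNorm._dygraph_clip` with made-up gradients:

```python
import paddle

clip_norm = 1.0
grads = [paddle.rand([4, 4]), paddle.rand([8])]

global_norm = paddle.sqrt(paddle.add_n([paddle.sum(g * g) for g in grads]))
max_norm = paddle.full(shape=[], dtype=global_norm.dtype, fill_value=clip_norm)

print(global_norm.shape, max_norm.shape)        # [] []
clip_coef = max_norm / paddle.maximum(global_norm, max_norm)
print(clip_coef.shape)                          # [] -> the scaling factor stays 0-D
```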
4 changes: 2 additions & 2 deletions python/paddle/nn/quant/lsq.py
@@ -178,7 +178,7 @@ def __init__(
s_attr = ParamAttr(
name=self._scale_name, initializer=Constant(1.0), trainable=True
)
self.s = self.create_parameter(shape=[1], attr=s_attr, dtype='float32')
self.s = self.create_parameter(shape=[], attr=s_attr, dtype='float32')
self.s.stop_gradient = False

if not self.symmetric:
@@ -189,7 +189,7 @@
name=self._beta_name, initializer=Constant(0.0), trainable=True
)
self.beta = self.create_parameter(
shape=[1], attr=beta_attr, dtype='float32'
shape=[], attr=beta_attr, dtype='float32'
)
self.beta.stop_gradient = False

5 changes: 1 addition & 4 deletions test/autograd/utils.py
@@ -26,10 +26,7 @@
# Finite Difference Utils
##########################################################
def _product(t):
if isinstance(t, int):
return t
else:
return np.product(t)
return int(np.product(t))


def _get_item(t, idx):
2 changes: 1 addition & 1 deletion test/dygraph_to_static/seq2seq_dygraph_model.py
@@ -407,7 +407,7 @@ def beam_search(self, inputs):
parent_ids = []

for step_idx in range(paddle.to_tensor(self.beam_max_step_num)):
if paddle.sum(1 - beam_finished).numpy()[0] == 0:
if paddle.sum(1 - beam_finished) == 0:
break
step_input = self._merge_batch_beams(step_input)
new_dec_hidden, new_dec_cell = [], []
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_for_enumerate.py
@@ -28,7 +28,7 @@
def for_in_range(x):
z = paddle.tensor.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]):
for i in range(x.numpy().item()):
z = z + i
return z

4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_sentiment.py
@@ -342,7 +342,7 @@ def train(args, to_static):

model.train()
avg_cost, prediction, acc = model(doc, label)
loss_data.append(avg_cost.numpy()[0])
loss_data.append(float(avg_cost))

avg_cost.backward()
sgd_optimizer.minimize(avg_cost)
@@ -358,7 +358,7 @@ def train(args, to_static):
"step: %d, ave loss: %f, speed: %f steps/s"
% (
batch_id,
avg_cost.numpy()[0],
float(avg_cost),
args.log_step / used_time,
)
)