From 5d21526fb4cd44d8943a0a094690bec3be0d115b Mon Sep 17 00:00:00 2001 From: gsq7474741 Date: Mon, 8 Nov 2021 00:25:05 +0800 Subject: [PATCH 1/9] add nn.AdaptiveLogSoftmaxWithLoss 0.1 --- .../test_adaptive_log_softmax_with_loss.py | 151 +++++++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/loss.py | 248 +++++++++++++++++- 4 files changed, 401 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py new file mode 100644 index 00000000000000..dfb1dbced885ea --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +from paddle import nn +from paddle.nn import functional as F + +np.random.seed(10) +paddle.seed(10) + + +class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): + def test_adaptive_log_softmax(self): + # args validation + with self.assertRaises(ValueError): + _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 15], div_value=2.) + + with self.assertRaises(ValueError): + _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 10], div_value=2.) + + with self.assertRaises(ValueError): + _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 25], div_value=2.) + + with self.assertRaisesRegex(ValueError, + "cutoffs should be a sequence of unique,"): + _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 20], div_value=2.) + + # not raise + _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 19], div_value=2.) + + # input shapes + with self.assertRaisesRegex( + RuntimeError, r"Input and target should have the same size"): + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 10, 15], div_value=2.) + x = paddle.randn((2, 16)) + y = paddle.to_tensor([0, 5, 10]) + asfm(x, y) + + # out-of-bound targets + with self.assertRaisesRegex(RuntimeError, + r"Target values should be in"): + asfm = nn.AdaptiveLogSoftmaxWithLoss( + 16, 20, [5, 10, 15], div_value=2.) + x = paddle.randn((128, 16)) + y = paddle.randint(low=21, high=200, shape=[128]) + asfm(x, y) + + # cluster sizes + asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.) 
+        x = paddle.randn((128, 16))
+        y = paddle.randint(low=0, high=20, shape=[128])
+        # x = paddle.randn((3, 16))
+        # y = paddle.to_tensor((0, 17))
+
+        self.assertEqual(
+            asfm.head.weight.shape,
+            [16, 5 + 3])  # 5 targets in head, 3 clusters, dimensionality 16
+        self.assertEqual(asfm.tail[0][1].weight.shape,
+                         [8, 5])  # 5 targets in this cluster, dimensionality 8
+        self.assertEqual(asfm.tail[1][1].weight.shape, [4, 5])
+        self.assertEqual(asfm.tail[2][1].weight.shape, [2, 5])
+
+        self.assertEqual(asfm(x, y).output.shape, [128])
+
+        # log_probs actually returns log_proba
+        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.)
+        x = paddle.randn((4, 8))
+        logprob_out = asfm.log_prob(x)
+        np.testing.assert_array_almost_equal(
+            paddle.exp(logprob_out).sum(1), paddle.ones([4]))
+        # if_equal=(paddle.abs(paddle.exp(logprob_out).sum(1)-paddle.ones([4]))
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
+class AdaptiveLogSoftmaxWithLoss(Layer):
+    r"""Efficient softmax approximation as described in
+    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
+    Moustapha Cissé, David Grangier, and Hervé Jégou
+    <https://arxiv.org/abs/1609.04309>`__.
+    Adaptive softmax is an approximate strategy for training models with large
+    output spaces. It is most effective when the label distribution is highly
+    imbalanced, for example in natural language modelling, where the word
+    frequency distribution approximately follows `Zipf's law`_.
+    Adaptive softmax partitions the labels into several clusters, according to
+    their frequency. These clusters may each contain a different number of
+    targets.
+    Additionally, clusters containing less frequent labels assign
+    lower-dimensional embeddings to those labels, which speeds up the
+    computation.
+    For each minibatch, only clusters for which at least one target is
+    present are evaluated.
+    The idea is that the clusters that are accessed frequently
+    (like the first one, containing the most frequent labels) should also be
+    cheap to compute -- that is, contain a small number of assigned labels.
+    We highly recommend taking a look at the original paper for more details.
+    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
+    in increasing order.
+    It controls the number of clusters and the partitioning of targets into
+    clusters. For example, setting ``cutoffs = [10, 100, 1000]``
+    means that the first `10` targets will be assigned
+    to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
+    assigned to the first cluster, and targets `101, 102, ..., 1000` will be
+    assigned to the second cluster, while targets
+    `1001, 1002, ..., n_classes - 1` will be assigned
+    to the last, third cluster.
+    * :attr:`div_value` is used to compute the size of each additional cluster,
+    which is given as
+    :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
+    where :math:`idx` is the cluster index (with clusters
+    for less frequent words having larger indices,
+    and indices starting from :math:`1`).
+    * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
+    adaptive softmax. See the paper for details. Set to False in the official
+    implementation.
+    .. warning::
+        Labels passed as inputs to this module should be sorted according to
+        their frequency. This means that the most frequent label should be
+        represented by the index `0`, and the least frequent
+        label should be represented by the index `n_classes - 1`.
+    .. note::
+        This module returns a ``NamedTuple`` with ``output``
+        and ``loss`` fields. See further documentation for details.
+    .. note::
+        To compute log-probabilities for all classes, the ``log_prob``
+        method can be used.
+ Args: + in_features (int): Number of features in the input tensor + n_classes (int): Number of classes in the dataset + cutoffs (Sequence): Cutoffs used to assign targets to their buckets + div_value (float, optional): value used as an exponent to compute sizes + of the clusters. Default: 4.0 + head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the + adaptive softmax. Default: ``False`` + Returns: + ``NamedTuple`` with ``output`` and ``loss`` fields: + * **output** is a Tensor of size ``N`` containing computed target + log probabilities for each example + * **loss** is a Scalar representing the computed negative + log likelihood loss + Shape: + - input: :math:`(N, \texttt{in\_features})` + - target: :math:`(N)` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}` + - output1: :math:`(N)` + - output2: ``Scalar`` + .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law + """ + + + def __init__( + self, + in_features: int, + n_classes: int, + cutoffs: Sequence[int], + div_value: float = 4., + head_bias: bool = False, + ) -> None: + super(AdaptiveLogSoftmaxWithLoss, self).__init__() + + cutoffs = list(cutoffs) + + if (cutoffs != sorted(cutoffs)) \ + or (min(cutoffs) <= 0) \ + or (max(cutoffs) > (n_classes - 1)) \ + or (len(set(cutoffs)) != len(cutoffs)) \ + or any([int(c) != c for c in cutoffs]): + + raise ValueError("cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1") + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self.head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head = Linear(self.in_features, self.head_size, bias_attr=self.head_bias) + self.tail = LayerList() + + for i in range(self.n_clusters): + + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + + projection = Sequential( + Linear(self.in_features, hsz, bias_attr=False), + Linear(hsz, osz, bias_attr=False), + ) + + self.tail.append(projection) + + def forward(self, input: Tensor, target: Tensor) -> _ASMoutput: + if input.shape[0] != target.shape[0]: + raise RuntimeError('Input and target should have the same size ' + 'in the batch dimension.') + + used_rows = 0 + batch_size = target.shape[0] + + output = paddle.zeros([batch_size]) + gather_inds = paddle.empty([batch_size]) + + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + + low_idx = cutoff_values[i] + high_idx = cutoff_values[i + 1] + + target_mask = (target >= low_idx).logical_and(target < high_idx) + row_indices = target_mask.nonzero().squeeze() + + if row_indices.numel() == 0: + continue + if i == 0: + scatter_output = paddle.scatter_nd(row_indices.unsqueeze(1), target.masked_select(target_mask), gather_inds.shape) + # gather_inds = gather_inds * (scatter_output == 0) + scatter_output + gather_inds=scatter_output + else: + relative_target = target.masked_select(target_mask) - low_idx + input_subset = input.index_select(row_indices, 0) + + cluster_output = self.tail[i - 1](input_subset) + cluster_index = self.shortlist_size + i - 1 + + scatter_output = paddle.scatter_nd(row_indices.unsqueeze(1), paddle.ones(row_indices.shape)*cluster_index, gather_inds.shape) + gather_inds = (gather_inds * (scatter_output != cluster_index) + 
scatter_output).astype(paddle.int64)
+
+                cluster_logprob = F.log_softmax(cluster_output, axis=1)
+                local_logprob = (F.one_hot(relative_target, cluster_logprob.shape[-1]) * cluster_logprob).sum(1).unsqueeze(1)
+                scatter_output = paddle.scatter_nd(row_indices.unsqueeze(1), local_logprob.squeeze(1), output.shape)
+                output = output * (scatter_output == 0) + scatter_output
+
+            used_rows += row_indices.numel()
+
+        if used_rows != batch_size:
+            raise RuntimeError("Target values should be in [0, {}], "
+                               "but values in range [{}, {}] "
+                               "were found. ".format(self.n_classes - 1,
+                                                     target.min().item(),
+                                                     target.max().item()))
+
+        head_output = self.head(input)
+        head_logprob = F.log_softmax(head_output, axis=1)
+        output += (paddle.nn.functional.one_hot(gather_inds, head_logprob.shape[1]) * head_logprob).sum(1)
+        loss = (-output).mean()
+
+        return _ASMoutput(output, loss)
+
+    def _get_full_log_prob(self, input, head_output):
+        """ Given the input tensor and the output of `self.head`,
+        compute the log of the full distribution """
+
+        out = paddle.empty((head_output.shape[0], self.n_classes))
+        head_logprob = F.log_softmax(head_output, axis=1)
+
+        out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]
+
+        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
+            cluster_output = self.tail[i](input)
+            cluster_logprob = F.log_softmax(cluster_output, axis=1)
+            output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)
+
+            out[:, start_idx:stop_idx] = output_logprob
+
+        return out
+
+    def log_prob(self, input: Tensor) -> Tensor:
+        r""" Computes log probabilities for all :math:`\texttt{n\_classes}` classes
+        Args:
+            input (Tensor): a minibatch of examples
+        Returns:
+            log probabilities for each class :math:`c`
+            in the range :math:`0 <= c < \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
+            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.
+        Shape:
+            - Input: :math:`(N, \texttt{in\_features})`
+            - Output: :math:`(N, \texttt{n\_classes})`
+        """
+
+        head_output = self.head(input)
+        return self._get_full_log_prob(input, head_output)
+
+    def predict(self, input: Tensor) -> Tensor:
+        r""" This is equivalent to `self.log_prob(input).argmax(axis=1)`,
+        but is more efficient in some cases.
+ Args: + input (Tensor): a minibatch of examples + Returns: + output (Tensor): a class with the highest probability for each example + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N)` + """ + + head_output = self.head(input) + output = paddle.argmax(head_output, axis=1).cast('float32') + not_in_shortlist = (output >= self.shortlist_size) + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return paddle.argmax(log_prob, axis=1) + + else: + log_prob = self._get_full_log_prob(input[not_in_shortlist], + head_output[not_in_shortlist]) + output[not_in_shortlist] = paddle.argmax(log_prob, axis=1).cast('float32') + return output class BCEWithLogitsLoss(Layer): From dd3c94a7e6c28af0fe1b7ade59ecf31c5986c302 Mon Sep 17 00:00:00 2001 From: gsq7474741 Date: Mon, 8 Nov 2021 01:55:05 +0800 Subject: [PATCH 2/9] add nn.AdaptiveLogSoftmaxWithLoss v0.2 --- python/paddle/nn/layer/__init__.py | 2 +- python/paddle/nn/layer/loss.py | 54 +++++++++++++++++++----------- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index d64fedde9be579..994e374574b111 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -56,7 +56,7 @@ from .pooling import AdaptiveMaxPool2D # noqa: F401 from .pooling import AdaptiveMaxPool3D # noqa: F401 from .pooling import MaxUnPool2D # noqa: F401 -from .loss import AdaptiveLogSoftmaxWithLoss # noqa: F401 +from .loss import AdaptiveLogSoftmaxWithLoss # noqa: F401 from .conv import Conv1D # noqa: F401 from .conv import Conv2D # noqa: F401 from .conv import Conv3D # noqa: F401 diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 3e5801e97c9ff3..da29ca34604df7 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -29,6 +29,7 @@ __all__ = [] _ASMoutput = namedtuple('_ASMoutput', ['output', 'loss']) + class AdaptiveLogSoftmaxWithLoss(Layer): r"""Efficient softmax approximation as described in `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, @@ -101,15 +102,13 @@ class AdaptiveLogSoftmaxWithLoss(Layer): .. 
_Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law """ - def __init__( self, in_features: int, n_classes: int, cutoffs: Sequence[int], - div_value: float = 4., - head_bias: bool = False, - ) -> None: + div_value: float=4., + head_bias: bool=False, ) -> None: super(AdaptiveLogSoftmaxWithLoss, self).__init__() cutoffs = list(cutoffs) @@ -134,18 +133,20 @@ def __init__( self.n_clusters = len(self.cutoffs) - 1 self.head_size = self.shortlist_size + self.n_clusters - self.head = Linear(self.in_features, self.head_size, bias_attr=self.head_bias) + self.head = Linear( + self.in_features, self.head_size, bias_attr=self.head_bias) self.tail = LayerList() for i in range(self.n_clusters): - hsz = int(self.in_features // (self.div_value ** (i + 1))) + hsz = int(self.in_features // (self.div_value**(i + 1))) osz = self.cutoffs[i + 1] - self.cutoffs[i] projection = Sequential( - Linear(self.in_features, hsz, bias_attr=False), - Linear(hsz, osz, bias_attr=False), - ) + Linear( + self.in_features, hsz, bias_attr=False), + Linear( + hsz, osz, bias_attr=False), ) self.tail.append(projection) @@ -172,9 +173,11 @@ def forward(self, input: Tensor, target: Tensor) -> _ASMoutput: if row_indices.numel() == 0: continue if i == 0: - scatter_output = paddle.scatter_nd(row_indices.unsqueeze(1), target.masked_select(target_mask), gather_inds.shape) + scatter_output = paddle.scatter_nd( + row_indices.unsqueeze(1), + target.masked_select(target_mask), gather_inds.shape) # gather_inds = gather_inds * (scatter_output == 0) + scatter_output - gather_inds=scatter_output + gather_inds = scatter_output else: relative_target = target.masked_select(target_mask) - low_idx input_subset = input.index_select(row_indices, 0) @@ -182,12 +185,20 @@ def forward(self, input: Tensor, target: Tensor) -> _ASMoutput: cluster_output = self.tail[i - 1](input_subset) cluster_index = self.shortlist_size + i - 1 - scatter_output = paddle.scatter_nd(row_indices.unsqueeze(1), paddle.ones(row_indices.shape)*cluster_index, gather_inds.shape) - gather_inds = (gather_inds * (scatter_output != cluster_index) + scatter_output).astype(paddle.int64) + scatter_output = paddle.scatter_nd( + row_indices.unsqueeze(1), + paddle.ones(row_indices.shape) * cluster_index, + gather_inds.shape) + gather_inds = (gather_inds * (scatter_output != cluster_index) + + scatter_output).astype(paddle.int64) cluster_logprob = F.log_softmax(cluster_output, axis=1) - local_logprob = (F.one_hot(relative_target, cluster_logprob.shape[-1]) * cluster_logprob).sum(1).unsqueeze(1) - scatter_output = paddle.scatter_nd(row_indices.unsqueeze(1), local_logprob.squeeze(1), output.shape) + local_logprob = (F.one_hot(relative_target, + cluster_logprob.shape[-1]) * + cluster_logprob).sum(1).unsqueeze(1) + scatter_output = paddle.scatter_nd( + row_indices.unsqueeze(1), + local_logprob.squeeze(1), output.shape) output = output * (scatter_output == 0) + scatter_output used_rows += row_indices.numel() @@ -201,7 +212,8 @@ def forward(self, input: Tensor, target: Tensor) -> _ASMoutput: head_output = self.head(input) head_logprob = F.log_softmax(head_output, axis=1) - output += (paddle.nn.functional.one_hot(gather_inds, head_logprob.shape[1]) * head_logprob).sum(1) + output += (paddle.nn.functional.one_hot( + gather_inds, head_logprob.shape[1]) * head_logprob).sum(1) loss = (-output).mean() return _ASMoutput(output, loss) @@ -215,10 +227,13 @@ def _get_full_log_prob(self, input, head_output): out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size] - for i, (start_idx, stop_idx) in 
enumerate(zip(self.cutoffs, self.cutoffs[1:])): + for i, (start_idx, + stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): cluster_output = self.tail[i](input) cluster_logprob = F.log_softmax(cluster_output, axis=1) - output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1) + output_logprob = cluster_logprob + head_logprob[:, + self.shortlist_size + + i].unsqueeze(1) out[:, start_idx:stop_idx] = output_logprob @@ -267,7 +282,8 @@ def predict(self, input: Tensor) -> Tensor: else: log_prob = self._get_full_log_prob(input[not_in_shortlist], head_output[not_in_shortlist]) - output[not_in_shortlist] = paddle.argmax(log_prob, axis=1).cast('float32') + output[not_in_shortlist] = paddle.argmax( + log_prob, axis=1).cast('float32') return output From 805a44528359397c5489811dc83f1001984e2b9d Mon Sep 17 00:00:00 2001 From: gsq7474741 Date: Thu, 11 Nov 2021 23:01:05 +0800 Subject: [PATCH 3/9] add nn.AdaptiveLogSoftmaxWithLoss v0.3 --- .../test_adaptive_log_softmax_with_loss.py | 18 +++++------------- python/paddle/nn/layer/loss.py | 6 +++++- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py index dfb1dbced885ea..d85121f435c28f 100644 --- a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py +++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py @@ -23,7 +23,7 @@ class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): - def test_adaptive_log_softmax(self): + def test_error(self): # args validation with self.assertRaises(ValueError): _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 15], div_value=2.) @@ -37,10 +37,10 @@ def test_adaptive_log_softmax(self): with self.assertRaisesRegex(ValueError, "cutoffs should be a sequence of unique,"): _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 20], div_value=2.) - # not raise _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 19], div_value=2.) + def test_shape(self): # input shapes with self.assertRaisesRegex( RuntimeError, r"Input and target should have the same size"): @@ -59,12 +59,11 @@ def test_adaptive_log_softmax(self): y = paddle.randint(low=21, high=200, shape=[128]) asfm(x, y) + def test_cluster(self): # cluster sizes asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.) x = paddle.randn((128, 16)) y = paddle.randint(low=0, high=20, shape=[128]) - # x = paddle.randn((3, 16)) - # y = paddle.to_tensor((0, 17)) self.assertEqual( asfm.head.weight.shape, @@ -76,33 +75,26 @@ def test_adaptive_log_softmax(self): self.assertEqual(asfm(x, y).output.shape, [128]) + def test_log_probs(self): # log_probs actually returns log_proba asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.) x = paddle.randn((4, 8)) logprob_out = asfm.log_prob(x) np.testing.assert_array_almost_equal( paddle.exp(logprob_out).sum(1), paddle.ones([4])) - # if_equal=(paddle.abs(paddle.exp(logprob_out).sum(1)-paddle.ones([4]))`__. 
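A note on the cluster geometry asserted by test_cluster above: the head projects onto shortlist_size + n_clusters outputs, and each tail cluster i gets a bottleneck of width floor(in_features / div_value**(i + 1)). The following standalone sketch mirrors the sizing rule from __init__ in PATCH 1/9; cluster_shapes is a local illustrative helper, not part of the Paddle API:

    # Sketch only: reproduces the projection shapes asserted in test_cluster.
    def cluster_shapes(in_features, n_classes, cutoffs, div_value=4.0):
        cutoffs = list(cutoffs) + [n_classes]
        shortlist_size = cutoffs[0]
        n_clusters = len(cutoffs) - 1
        # The head scores the shortlist targets plus one logit per cluster.
        head_shape = [in_features, shortlist_size + n_clusters]
        tail_shapes = []
        for i in range(n_clusters):
            hsz = int(in_features // (div_value ** (i + 1)))  # reduced width
            osz = cutoffs[i + 1] - cutoffs[i]  # targets in this cluster
            tail_shapes.append([hsz, osz])
        return head_shape, tail_shapes

    # cluster_shapes(16, 20, [5, 10, 15], div_value=2.0)
    # -> ([16, 8], [[8, 5], [4, 5], [2, 5]]), matching the asserts above.

Each successive cluster shrinks the projection width (halving it for div_value=2), which is what keeps the rarely visited clusters cheap to evaluate.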
From 05732ad8df8a45edc8e4789b7e64b6bcd994909f Mon Sep 17 00:00:00 2001
From: gsq7474741 
Date: Fri, 19 Nov 2021 10:10:47 +0800
Subject: [PATCH 4/9] add zeropad2d v0.5

---
 .../test_adaptive_log_softmax_with_loss.py    | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
index d85121f435c28f..5ba6d124f0f6db 100644
--- a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
@@ -138,6 +138,30 @@ def test_correct(self):
         np.testing.assert_array_almost_equal(
             out, asfm.log_prob(x).argmax(axis=1))
 
+    def linear_ref(self, x, weight, bias):
+        x = x.numpy() if isinstance(x, paddle.Tensor) else x
+        weight = weight.numpy() if isinstance(weight, paddle.Tensor) else weight
+        bias = bias.numpy() if isinstance(bias, paddle.Tensor) else bias
+        return np.matmul(x, weight) + bias
+
+    def test_approx_to_lsfm(self):
+        """
+        Test that the error between AdaptiveLogSoftmaxWithLoss and log_softmax is less than 3%, according to https://arxiv.org/abs/1609.04309
+        """
+        x = paddle.abs(paddle.randn((64, 8)))
+        asfm = nn.AdaptiveLogSoftmaxWithLoss(
+            8, 10, [6, 8], div_value=2., head_bias=True)
+
+        head_weight = asfm.head.weight.detach()
+        head_bias = asfm.head.bias.detach()
+        ref_head_output = self.linear_ref(x, head_weight, head_bias)
+
+        out = asfm.log_prob(x).argmax(axis=1)
+        ref = F.log_softmax(
+            paddle.to_tensor(
+                ref_head_output, dtype='float32')).argmax(axis=1)
+        self.assertTrue(out[out != ref].shape[0] < out.shape[0] / 30)
+
 
 if __name__ == "__main__":
     unittest.main()

From 0dfaaffe8995141b327e29fdb16be363e2f61a0f Mon Sep 17 00:00:00 2001
From: gsq7474741 
Date: Fri, 26 Nov 2021 10:06:10 +0800
Subject: [PATCH 5/9] add asfm v0.4 rerun ci

---
 .../tests/unittests/test_adaptive_log_softmax_with_loss.py     | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
index 5ba6d124f0f6db..2bc95070da2506 100644
--- a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
@@ -146,7 +146,7 @@ def linear_ref(self, x, weight, bias):
 
     def test_approx_to_lsfm(self):
         """
-        Test that the error between AdaptiveLogSoftmaxWithLoss and log_softmax is less than 3%, according to https://arxiv.org/abs/1609.04309
+        Test that the error between AdaptiveLogSoftmaxWithLoss and log_softmax is less than 3%, according to http://arxiv.org/abs/1609.04309
         """
         x = paddle.abs(paddle.randn((64, 8)))
         asfm = nn.AdaptiveLogSoftmaxWithLoss(

From e8b8857692a9a2b2dd1b5cd6a119f28b9e5c6c32 Mon Sep 17 00:00:00 2001
From: gsq7474741 
Date: Fri, 26 Nov 2021 11:12:03 +0800
Subject: [PATCH 6/9] add asfm v0.4 rerun ci

---
 .../tests/unittests/test_adaptive_log_softmax_with_loss.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
index 2bc95070da2506..e9f401d84984fc 100644
--- a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
@@ -18,9 +18,6 @@
 from paddle import nn
 from paddle.nn import functional as F
 
-np.random.seed(10)
-paddle.seed(10) - class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): def test_error(self): From 6bb989c8dc7d8774527715ea5dad53c8a79ee8cd Mon Sep 17 00:00:00 2001 From: gsq7474741 Date: Fri, 26 Nov 2021 12:51:34 +0800 Subject: [PATCH 7/9] add asfm v0.4 rerun ci --- .../unittests/test_adaptive_log_softmax_with_loss.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py index e9f401d84984fc..e73650a2aef569 100644 --- a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py +++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py @@ -20,6 +20,11 @@ class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): + def setUp(self): + paddle.disable_static() + paddle.seed(10) + np.random.seed(10) + def test_error(self): # args validation with self.assertRaises(ValueError): @@ -127,9 +132,9 @@ def test_correct(self): x[32:, asfm.shortlist_size:] *= 0. asfm.head.weight.detach()[:asfm.shortlist_size, - asfm.shortlist_size:] *= 0. + asfm.shortlist_size:] *= 0. asfm.head.weight.detach()[asfm.shortlist_size:, : - asfm.shortlist_size] *= 0. + asfm.shortlist_size] *= 0. out = asfm.predict(x) np.testing.assert_array_almost_equal( From c934522a989a11186ba651c1cc2ad10223530fba Mon Sep 17 00:00:00 2001 From: gsq7474741 Date: Mon, 29 Nov 2021 20:42:15 +0800 Subject: [PATCH 8/9] add asfm v0.4 rerun ci --- .../tests/unittests/test_adaptive_log_softmax_with_loss.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py index e73650a2aef569..d68052978046e3 100644 --- a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py +++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py @@ -22,8 +22,6 @@ class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase): def setUp(self): paddle.disable_static() - paddle.seed(10) - np.random.seed(10) def test_error(self): # args validation From 157c2462e884026ead7e768fa9284f9e3952651d Mon Sep 17 00:00:00 2001 From: gsq7474741 Date: Tue, 7 Dec 2021 16:42:50 +0800 Subject: [PATCH 9/9] add asfm v0.5 --- .../test_adaptive_log_softmax_with_loss.py | 28 ++----------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py index d68052978046e3..07f8d6a542e858 100644 --- a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py +++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py @@ -130,38 +130,14 @@ def test_correct(self): x[32:, asfm.shortlist_size:] *= 0. asfm.head.weight.detach()[:asfm.shortlist_size, - asfm.shortlist_size:] *= 0. + asfm.shortlist_size:] *= 0. asfm.head.weight.detach()[asfm.shortlist_size:, : - asfm.shortlist_size] *= 0. + asfm.shortlist_size] *= 0. 
 out = asfm.predict(x)
         np.testing.assert_array_almost_equal(
             out, asfm.log_prob(x).argmax(axis=1))
 
-    def linear_ref(self, x, weight, bias):
-        x = x.numpy() if isinstance(x, paddle.Tensor) else x
-        weight = weight.numpy() if isinstance(weight, paddle.Tensor) else weight
-        bias = bias.numpy() if isinstance(bias, paddle.Tensor) else bias
-        return np.matmul(x, weight) + bias
-
-    def test_approx_to_lsfm(self):
-        """
-        Test that the error between AdaptiveLogSoftmaxWithLoss and log_softmax is less than 3%, according to http://arxiv.org/abs/1609.04309
-        """
-        x = paddle.abs(paddle.randn((64, 8)))
-        asfm = nn.AdaptiveLogSoftmaxWithLoss(
-            8, 10, [6, 8], div_value=2., head_bias=True)
-
-        head_weight = asfm.head.weight.detach()
-        head_bias = asfm.head.bias.detach()
-        ref_head_output = self.linear_ref(x, head_weight, head_bias)
-
-        out = asfm.log_prob(x).argmax(axis=1)
-        ref = F.log_softmax(
-            paddle.to_tensor(
-                ref_head_output, dtype='float32')).argmax(axis=1)
-        self.assertTrue(out[out != ref].shape[0] < out.shape[0] / 30)
-
 
 if __name__ == "__main__":
     unittest.main()
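For reference, a minimal usage sketch of the layer as it stands after PATCH 9/9. It assumes the class is exported as paddle.nn.AdaptiveLogSoftmaxWithLoss (the wiring added to python/paddle/nn/__init__.py in PATCH 1/9); the shapes and cutoffs mirror the unit tests:

    import paddle
    from paddle import nn

    paddle.seed(10)

    # 20 classes: targets 0-4 sit in the head; 5-9, 10-14 and 15-19 fall
    # into three tail clusters of decreasing projection width.
    asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)

    x = paddle.randn((128, 16))
    y = paddle.randint(low=0, high=20, shape=[128])

    out, loss = asfm(x, y)        # _ASMoutput namedtuple: (output, loss)
    print(out.shape)              # [128]: per-example target log probability
    print(loss.numpy())           # scalar: mean negative log likelihood

    log_probs = asfm.log_prob(x)  # [128, 20]: full log distribution
    preds = asfm.predict(x)       # argmax; cheap when predictions hit the head

Note that forward consumes (input, target) pairs and only evaluates the clusters that the batch's targets actually touch, whereas log_prob and predict score all classes.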