diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
new file mode 100644
index 00000000000000..07f8d6a542e858
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_adaptive_log_softmax_with_loss.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+
+
+class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+    def test_error(self):
+        # argument validation
+        with self.assertRaises(ValueError):
+            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 15], div_value=2.)
+
+        with self.assertRaises(ValueError):
+            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 10], div_value=2.)
+
+        with self.assertRaises(ValueError):
+            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 25], div_value=2.)
+
+        with self.assertRaisesRegex(ValueError,
+                                    "cutoffs should be a sequence of unique,"):
+            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 20], div_value=2.)
+
+        # should not raise
+        _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 19], div_value=2.)
+
+    def test_shape(self):
+        # mismatched input and target batch sizes
+        with self.assertRaisesRegex(
+                RuntimeError, r"Input and target should have the same size"):
+            asfm = nn.AdaptiveLogSoftmaxWithLoss(
+                16, 20, [5, 10, 15], div_value=2.)
+            x = paddle.randn((2, 16))
+            y = paddle.to_tensor([0, 5, 10])
+            asfm(x, y)
+
+        # out-of-bound targets
+        with self.assertRaisesRegex(RuntimeError,
+                                    r"Target values should be in"):
+            asfm = nn.AdaptiveLogSoftmaxWithLoss(
+                16, 20, [5, 10, 15], div_value=2.)
+            x = paddle.randn((128, 16))
+            y = paddle.randint(low=21, high=200, shape=[128])
+            asfm(x, y)
+
+    def test_cluster(self):
+        # cluster sizes
+        asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
+        x = paddle.randn((128, 16))
+        y = paddle.randint(low=0, high=20, shape=[128])
+
+        self.assertEqual(
+            asfm.head.weight.shape,
+            [16, 5 + 3])  # 5 targets in head, 3 clusters, dimensionality 16
+        self.assertEqual(asfm.tail[0][1].weight.shape,
+                         [8, 5])  # 5 targets in this cluster, dimensionality 8
+        self.assertEqual(asfm.tail[1][1].weight.shape, [4, 5])
+        self.assertEqual(asfm.tail[2][1].weight.shape, [2, 5])
+
+        self.assertEqual(asfm(x, y).output.shape, [128])
+
+    def test_log_probs(self):
+        # log_prob returns valid log-probabilities (they sum to 1 after exp)
+        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.)
+        x = paddle.randn((4, 8))
+        logprob_out = asfm.log_prob(x)
+        np.testing.assert_array_almost_equal(
+            paddle.exp(logprob_out).sum(1), paddle.ones([4]))
+
+        # forward agrees with log_prob
+        for v in [0, 1, 2, 3]:
+            y = paddle.full((4, ), v, dtype='int64')
+            out, loss = asfm(x, y)
+            np.testing.assert_array_almost_equal(
+                out,
+                logprob_out.gather(y.unsqueeze(1), 1).slice([1], [0],
+                                                            [1]).squeeze())
+            np.testing.assert_array_almost_equal(loss,
+                                                 F.nll_loss(logprob_out, y))
+
+    def test_correct(self):
+        # predict
+        x = paddle.abs(paddle.randn((64, 8)))
+
+        # argmax in shortlist
+        asfm = nn.AdaptiveLogSoftmaxWithLoss(
+            8, 10, [4, 8], div_value=2., head_bias=True)
+        asfm.head.weight.detach().abs()
+        asfm.head.bias.detach().abs()
+        asfm.head.weight.detach()[asfm.shortlist_size:, :] *= 0.
+
+        out = asfm.predict(x)
+        np.testing.assert_array_almost_equal(
+            out, asfm.log_prob(x).argmax(axis=1))
+
+        # argmax outside of shortlist
+        asfm = nn.AdaptiveLogSoftmaxWithLoss(
+            8, 10, [4, 8], div_value=2., head_bias=True)
+        asfm.head.weight.detach().abs()
+        asfm.head.bias.detach().abs()
+        asfm.head.weight.detach()[:asfm.shortlist_size, :] *= 0.
+
+        out = asfm.predict(x)
+        np.testing.assert_array_almost_equal(
+            out, asfm.log_prob(x).argmax(axis=1))
+
+        # half of the argmax in shortlist, half in clusters
+        asfm = nn.AdaptiveLogSoftmaxWithLoss(
+            8, 10, [4, 8], div_value=2., head_bias=True)
+        asfm.head.weight.detach().abs()
+        asfm.head.bias.detach().abs()
+
+        x[:32, :asfm.shortlist_size] *= 0.
+        x[32:, asfm.shortlist_size:] *= 0.
+
+        asfm.head.weight.detach()[:asfm.shortlist_size,
+                                  asfm.shortlist_size:] *= 0.
+        asfm.head.weight.detach()[asfm.shortlist_size:, :
+                                  asfm.shortlist_size] *= 0.
+
+        out = asfm.predict(x)
+        np.testing.assert_array_almost_equal(
+            out, asfm.log_prob(x).argmax(axis=1))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 1abe74e9783dc4..ea78f17f81d1b4 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -90,6 +90,7 @@
 from .layer.conv import Conv2DTranspose  # noqa: F401
 from .layer.conv import Conv3DTranspose  # noqa: F401
+from .layer.loss import AdaptiveLogSoftmaxWithLoss  # noqa: F401
 from .layer.loss import BCEWithLogitsLoss  # noqa: F401
 from .layer.loss import CrossEntropyLoss  # noqa: F401
 from .layer.loss import HSigmoidLoss  # noqa: F401
@@ -276,6 +277,7 @@ def weight_norm(*args):
     'Conv3DTranspose',
     'Flatten',
     'AdaptiveAvgPool1D',
+    'AdaptiveLogSoftmaxWithLoss',
     'Tanhshrink',
     'HSigmoidLoss',
     'PReLU',
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index a65f9912d59391..662ba6a87591c7 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -57,6 +57,7 @@
 from .pooling import AdaptiveMaxPool2D  # noqa: F401
 from .pooling import AdaptiveMaxPool3D  # noqa: F401
 from .pooling import MaxUnPool2D  # noqa: F401
+from .loss import AdaptiveLogSoftmaxWithLoss  # noqa: F401
 from .conv import Conv1D  # noqa: F401
 from .conv import Conv2D  # noqa: F401
 from .conv import Conv3D  # noqa: F401
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 3ac0d675fb72c6..c2daa946116865 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -15,14 +15,280 @@
 # TODO: define loss functions of neural network
 
 import numpy as np
+from collections import namedtuple
+from typing import List, Sequence
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle
 from .. import functional as F
 from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
-from .. import Layer
+from .. import Layer, Sequential, LayerList
+from paddle.nn.layer import Linear
+from paddle import Tensor
 
 __all__ = []
 
+_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])
+
+
+class AdaptiveLogSoftmaxWithLoss(Layer):
+    r"""
+    BSD LICENSE From PyTorch:
+    Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+
+    Efficient softmax approximation as described in
+    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
+    Moustapha Cissé, David Grangier, and Hervé Jégou
+    <https://arxiv.org/abs/1609.04309>`__.
+
+    Adaptive softmax is an approximate strategy for training models with large
+    output spaces. It is most effective when the label distribution is highly
+    imbalanced, for example in natural language modelling, where the word
+    frequency distribution approximately follows `Zipf's law`_.
+
+    Adaptive softmax partitions the labels into several clusters, according to
+    their frequency. These clusters may contain different numbers of targets.
+    Additionally, clusters containing less frequent labels assign lower
+    dimensional embeddings to those labels, which speeds up the computation.
+    For each minibatch, only the clusters for which at least one target is
+    present are evaluated.
+
+    The idea is that the clusters which are accessed frequently
+    (like the first one, containing the most frequent labels) should also be
+    cheap to compute -- that is, contain a small number of assigned labels.
+    We highly recommend taking a look at the original paper for more details.
+
+    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
+      in increasing order. It controls the number of clusters and the
+      partitioning of targets into clusters. For example, setting
+      ``cutoffs = [10, 100, 1000]`` means that the first `10` targets will be
+      assigned to the 'head' of the adaptive softmax, targets
+      `11, 12, ..., 100` will be assigned to the first cluster, and targets
+      `101, 102, ..., 1000` will be assigned to the second cluster, while
+      targets `1001, 1002, ..., n_classes - 1` will be assigned to the last,
+      third cluster.
+
+    * :attr:`div_value` is used to compute the size of each additional cluster,
+      which is given as
+      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
+      where :math:`idx` is the cluster index (with clusters
+      for less frequent words having larger indices,
+      and indices starting from :math:`1`).
+
+    * :attr:`head_bias`, if set to True, adds a bias term to the 'head' of the
+      adaptive softmax. See the paper for details. Set to False in the official
+      implementation.
+
+    .. warning::
+        Labels passed as inputs to this module should be sorted according to
+        their frequency. This means that the most frequent label should be
+        represented by the index `0`, and the least frequent
+        label should be represented by the index `n_classes - 1`.
+
+    .. note::
+        This module returns a ``NamedTuple`` with ``output``
+        and ``loss`` fields. See further documentation for details.
+
+    .. note::
+        To compute log-probabilities for all classes, the ``log_prob``
+        method can be used.
+
+    Args:
+        in_features (int): Number of features in the input tensor
+        n_classes (int): Number of classes in the dataset
+        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
+        div_value (float, optional): value used as an exponent to compute sizes
+            of the clusters. Default: 4.0
+        head_bias (bool, optional): If ``True``, adds a bias term to the
+            'head' of the adaptive softmax.
+            Default: ``False``
+
+    Returns:
+        ``NamedTuple`` with ``output`` and ``loss`` fields:
+            * **output** is a Tensor of size ``N`` containing computed target
+              log probabilities for each example
+            * **loss** is a Scalar representing the computed negative
+              log likelihood loss
+
+    Shape:
+        - input: :math:`(N, \texttt{in\_features})`
+        - target: :math:`(N)` where each value satisfies :math:`0 <= \texttt{target[i]} < \texttt{n\_classes}`
+        - output1: :math:`(N)`
+        - output2: ``Scalar``
+
+    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
+    """
+
+    def __init__(
+            self,
+            in_features: int,
+            n_classes: int,
+            cutoffs: Sequence[int],
+            div_value: float=4.,
+            head_bias: bool=False, ) -> None:
+        super(AdaptiveLogSoftmaxWithLoss, self).__init__()
+
+        cutoffs = list(cutoffs)
+
+        if (cutoffs != sorted(cutoffs)) \
+                or (min(cutoffs) <= 0) \
+                or (max(cutoffs) > (n_classes - 1)) \
+                or (len(set(cutoffs)) != len(cutoffs)) \
+                or any([int(c) != c for c in cutoffs]):
+
+            raise ValueError("cutoffs should be a sequence of unique, positive "
+                             "integers sorted in an increasing order, where "
+                             "each value is between 1 and n_classes-1")
+
+        self.in_features = in_features
+        self.n_classes = n_classes
+        self.cutoffs = cutoffs + [n_classes]
+        self.div_value = div_value
+        self.head_bias = head_bias
+
+        self.shortlist_size = self.cutoffs[0]
+        self.n_clusters = len(self.cutoffs) - 1
+        self.head_size = self.shortlist_size + self.n_clusters
+
+        self.head = Linear(
+            self.in_features, self.head_size, bias_attr=self.head_bias)
+        self.tail = LayerList()
+
+        for i in range(self.n_clusters):
+            # each tail cluster projects to a smaller hidden size (hsz) and
+            # then to the number of classes it owns (osz)
+            hsz = int(self.in_features // (self.div_value**(i + 1)))
+            osz = self.cutoffs[i + 1] - self.cutoffs[i]
+
+            projection = Sequential(
+                Linear(
+                    self.in_features, hsz, bias_attr=False),
+                Linear(
+                    hsz, osz, bias_attr=False), )
+
+            self.tail.append(projection)
+
+    def forward(self, input: Tensor, target: Tensor) -> _ASMoutput:
+        if input.shape[0] != target.shape[0]:
+            raise RuntimeError('Input and target should have the same size '
+                               'in the batch dimension.')
+
+        used_rows = 0
+        batch_size = target.shape[0]
+
+        output = paddle.zeros([batch_size])
+        gather_inds = paddle.empty([batch_size])
+
+        cutoff_values = [0] + self.cutoffs
+        for i in range(len(cutoff_values) - 1):
+
+            low_idx = cutoff_values[i]
+            high_idx = cutoff_values[i + 1]
+
+            target_mask = (target >= low_idx).logical_and(target < high_idx)
+            row_indices = target_mask.nonzero().squeeze()
+
+            if row_indices.numel() == 0:
+                continue
+
+            if i == 0:
+                # shortlist targets are looked up directly in the head output
+                scatter_output = paddle.scatter_nd(
+                    row_indices.unsqueeze(1),
+                    target.masked_select(target_mask), gather_inds.shape)
+                # gather_inds = gather_inds * (scatter_output == 0) + scatter_output
+                gather_inds = scatter_output
+            else:
+                # tail targets: record the cluster column in the head and add
+                # the within-cluster log-probability to the output
+                relative_target = target.masked_select(target_mask) - low_idx
+                input_subset = input.index_select(row_indices, 0)
+
+                cluster_output = self.tail[i - 1](input_subset)
+                cluster_index = self.shortlist_size + i - 1
+
+                scatter_output = paddle.scatter_nd(
+                    row_indices.unsqueeze(1),
+                    paddle.ones(row_indices.shape) * cluster_index,
+                    gather_inds.shape)
+                gather_inds = (gather_inds * (scatter_output != cluster_index) +
+                               scatter_output).astype(paddle.int64)
+
+                cluster_logprob = F.log_softmax(cluster_output, axis=1)
+                local_logprob = (F.one_hot(relative_target,
+                                           cluster_logprob.shape[-1]) *
+                                 cluster_logprob).sum(1).unsqueeze(1)
+                scatter_output = paddle.scatter_nd(
+                    row_indices.unsqueeze(1),
+                    local_logprob.squeeze(1), output.shape)
+                output = output * (scatter_output == 0) + scatter_output
+
+            used_rows += row_indices.numel()
+
+        if used_rows != batch_size:
+            raise RuntimeError("Target values should be in [0, {}], "
+                               "but values in range [{}, {}] "
+                               "were found. ".format(self.n_classes - 1,
+                                                     target.min().item(),
+                                                     target.max().item()))
+
+        # add the head log-probability of the column selected for each row
+        # (the shortlist class itself or its cluster token)
+        head_output = self.head(input)
+        head_logprob = F.log_softmax(head_output, axis=1)
+        output += (paddle.nn.functional.one_hot(
+            gather_inds, head_logprob.shape[1]) * head_logprob).sum(1)
+        loss = (-output).mean()
+
+        return _ASMoutput(output, loss)
+
+    def _get_full_log_prob(self, input, head_output):
+        """ Given the input tensor and the output of `self.head`,
+        compute the log of the full distribution. """
+
+        out = paddle.empty((head_output.shape[0], self.n_classes))
+        head_logprob = F.log_softmax(head_output, axis=1)
+
+        out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]
+
+        for i, (start_idx,
+                stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
+            cluster_output = self.tail[i](input)
+            cluster_logprob = F.log_softmax(cluster_output, axis=1)
+            output_logprob = cluster_logprob + head_logprob[:,
+                                                            self.shortlist_size
+                                                            + i].unsqueeze(1)
+
+            out[:, start_idx:stop_idx] = output_logprob
+
+        return out
+
+    def log_prob(self, input: Tensor) -> Tensor:
+        r""" Computes log probabilities for all :math:`\texttt{n\_classes}` classes.
+
+        Args:
+            input (Tensor): a minibatch of examples
+
+        Returns:
+            log-probabilities for each class :math:`c`
+            in range :math:`0 <= c < \texttt{n\_classes}`, where
+            :math:`\texttt{n\_classes}` is a parameter passed to the
+            ``AdaptiveLogSoftmaxWithLoss`` constructor.
+
+        Shape:
+            - Input: :math:`(N, \texttt{in\_features})`
+            - Output: :math:`(N, \texttt{n\_classes})`
+        """
+
+        head_output = self.head(input)
+        return self._get_full_log_prob(input, head_output)
+
+    def predict(self, input: Tensor) -> Tensor:
+        r""" This is equivalent to `self.log_prob(input).argmax(axis=1)`,
+        but is more efficient in some cases.
+
+        Args:
+            input (Tensor): a minibatch of examples
+
+        Returns:
+            output (Tensor): the class with the highest probability for each example
+
+        Shape:
+            - Input: :math:`(N, \texttt{in\_features})`
+            - Output: :math:`(N)`
+        """
+
+        head_output = self.head(input)
+        output = paddle.argmax(head_output, axis=1).cast('float32')
+        not_in_shortlist = (output >= self.shortlist_size)
+        all_in_shortlist = not (not_in_shortlist.any())
+
+        if all_in_shortlist:
+            # every argmax column is a real shortlist class
+            return output
+
+        elif not_in_shortlist.all():
+            # every example is routed to a tail cluster; compute the full
+            # distribution for the whole batch
+            log_prob = self._get_full_log_prob(input, head_output)
+            return paddle.argmax(log_prob, axis=1)
+
+        else:
+            # mixed case: only the rows routed to a cluster need the full
+            # distribution
+            log_prob = self._get_full_log_prob(input[not_in_shortlist],
+                                               head_output[not_in_shortlist])
+            output[not_in_shortlist] = paddle.argmax(
+                log_prob, axis=1).cast('float32')
+            return output
+
+
 class BCEWithLogitsLoss(Layer):
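Reviewer note (not part of the patch): a minimal usage sketch of the layer introduced by this diff, assuming the API exactly as added above (`paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value, head_bias)`); the feature size, vocabulary size, and cutoffs below are illustrative only.

```python
import paddle
from paddle import nn

paddle.disable_static()

# 64-dimensional features, 1000 classes: a head shortlist of 100 labels plus
# two tail clusters covering labels 100-499 and 500-999 (labels must be
# sorted by decreasing frequency, index 0 = most frequent).
asfm = nn.AdaptiveLogSoftmaxWithLoss(64, 1000, cutoffs=[100, 500], div_value=4.)

x = paddle.randn((32, 64))                        # minibatch of features
y = paddle.randint(low=0, high=1000, shape=[32])  # frequency-sorted labels

out, loss = asfm(x, y)   # per-example target log-probabilities and mean NLL
full = asfm.log_prob(x)  # [32, 1000] log-probabilities over all classes
pred = asfm.predict(x)   # per-example argmax class, same as full.argmax(axis=1)
```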