Skip to content

Commit

Permalink
Merge pull request #561 from odlgroup/second_type_kl_divergence
Browse files Browse the repository at this point in the history
ENH: Add proximal operator for second kind of KL-divergence.
  • Loading branch information
aringh authored Sep 14, 2016
2 parents 411b4d0 + 994f754 commit 10d2cdb
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 44 deletions.
209 changes: 184 additions & 25 deletions odl/solvers/advanced/proximal_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from builtins import super

import numpy as np
import scipy.special

from odl.operator import (Operator, IdentityOperator, ScalingOperator,
ConstantOperator, ResidualOperator, DiagonalOperator)
Expand All @@ -45,7 +46,7 @@
'proximal_l1', 'proximal_cconj_l1',
'proximal_l2', 'proximal_cconj_l2',
'proximal_l2_squared', 'proximal_cconj_l2_squared',
'proximal_cconj_kl')
'proximal_cconj_kl', 'proximal_cconj_kl_cross_entropy')


# TODO: remove diagonal op once available on master
Expand Down Expand Up @@ -982,34 +983,18 @@ def proximal_cconj_kl(space, lam=1, g=None):
Function returning the proximal operator of the convex conjugate of the
functional F where F is the entropy-type Kullback-Leibler (KL) divergence
F(x) = sum_i (x - g + g ln(g) - g ln(pos(x)))_i + ind_P(x)
F(x) = sum_i (x_i - g_i + g_i ln(g_i) - g_i ln(pos(x_i))) + ind_P(x)
with x and g in X and g non-negative. The indicator function ind_P(x)
for the positive elements of x is used to restrict the domain of F such
that F is defined over whole X. The non-negativity thresholding pos is
used to define F in the real numbers.
The proximal operator of the convex conjugate F^* of F is
F^*(p) = sum_i (-g ln(pos(1_X - p))_i + ind_P(1_X - p)
where p is the variable dual to x, and 1_X is an element of the space
X with all components set to 1.
The proximal operator of the convex conjugate of F is
prox[sigma * F^*](x) =
1/2 (lam + x - sqrt((x - lam)^2 + 4 lam sigma g)
with the step size parameter sigma and lam_X is an element of the space X
with all components set to lam.
with ``x`` and ``g`` elements in the linear space ``X``, and ``g``
non-negative. Here, ``pos`` denotes the nonnegative part, and ``ind_P`` is
the indicator function for nonnegativity.
Parameters
----------
space : `DiscreteLp` or `ProductSpace` of `DiscreteLp` spaces
space : `FnBase`
Space X which is the domain of the functional F
g : ``space`` element
Data term.
g : ``space`` element, optional
Data term, positive. If None, it is taken as the one-element.
lam : positive float
Scaling factor.
Expand All @@ -1018,8 +1003,48 @@ def proximal_cconj_kl(space, lam=1, g=None):
prox_factory : function
Factory for the proximal operator to be initialized.
See Also
--------
proximal_cconj_kl_cross_entropy : proximal for related functional
Notes
-----
The functional is given by the expression
.. math::
F(x) = \\sum_i (x_i - g_i + g_i \\ln(g_i) - g_i \\ln(pos(x_i))) +
I_{x \\geq 0}(x)
The indicator function :math:`I_{x \\geq 0}(x)` is used to restrict the
domain of :math:`F` such that :math:`F` is defined over whole space
:math:`X`. The non-negativity thresholding :math:`pos` is used to define
:math:`F` in the real numbers.
Note that the functional is not well-defined without a prior g. Hence, if g
is omitted this will be interpreted as if g is equal to the one-element.
The convex conjugate :math:`F^*` of :math:`F` is
.. math::
F^*(p) = \\sum_i (-g_i \\ln(pos({1_X}_i - p_i))) +
I_{1_X - p \\geq 0}(p)
where :math:`p` is the variable dual to :math:`x`, and :math:`1_X` is an
element of the space :math:`X` with all components set to 1.
The proximal operator of the convex conjugate of F is
.. math::
prox[\\sigma * (\\lambda*F)^*](x) =
\\frac{\\lambda * 1_X + x - \\sqrt{(x - \\lambda * 1_X)^2 +
4 \\lambda \\sigma g}}{2}
where :math:`\\sigma` is the step size-like parameter, and :math:`\\lambda`
is the weighting in front of the function :math:`F`.
KL based objectives are common in MLEM optimization problems and are often
used when data noise governed by a multivariate Poisson probability
distribution is significant.
Expand All @@ -1028,7 +1053,15 @@ def proximal_cconj_kl(space, lam=1, g=None):
the converged solution will be non-negative. Non-negative intermediate
image estimates can be enforced by adding an indicator function ind_P
to the primal objective.
This functional :math:`F`, described above, is related to the
Kullback-Leibler cross entropy functional. The KL cross entropy is the one
described in `this Wikipedia article
<https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_, and
the functional :math:`F` is obtained by switching places of the prior and
the variable in the KL cross entropy functional. See the See Also section.
"""

lam = float(lam)

if g is not None and g not in space:
Expand Down Expand Up @@ -1061,7 +1094,10 @@ def _call(self, x, out):
out.ufunc.square(out=out)

# out = out + 4 lam sigma g
if g is not None:
# If g is None, it is taken as the one element
if g is None:
out += 4.0 * lam * self.sigma
else:
out.lincomb(1, out, 4.0 * lam * self.sigma, g)

# out = sqrt(out)
Expand All @@ -1079,6 +1115,129 @@ def _call(self, x, out):
return ProximalCConjKL


def proximal_cconj_kl_cross_entropy(space, lam=1, g=None):
    """Proximal factory of the convex conjugate of cross entropy KL divergence.

    Return the proximal factory of the convex conjugate of the functional
    F, where F is the cross entropy Kullback-Leibler (KL) divergence given
    by

        F(x) = sum_i (x_i ln(pos(x_i)) - x_i ln(g_i) + g_i - x_i) + ind_P(x)

    with ``x`` and ``g`` in the linear space ``X``, and ``g`` non-negative.
    Here, ``pos`` denotes the nonnegative part, and ``ind_P`` is the
    indicator function for nonnegativity.

    Parameters
    ----------
    space : `FnBase`
        Space X which is the domain of the functional F
    lam : positive float
        Scaling factor.
    g : ``space`` element, optional
        Data term, positive. If None, it is taken as the one-element.

    Returns
    -------
    prox_factory : function
        Factory for the proximal operator to be initialized.

    Raises
    ------
    TypeError
        If ``g`` is given but is not an element of ``space``.

    See Also
    --------
    proximal_cconj_kl : proximal for related functional

    Notes
    -----
    The functional is given by the expression

    .. math::
        F(x) = \\sum_i (x_i \\ln(pos(x_i)) - x_i \\ln(g_i) + g_i - x_i) +
        I_{x \\geq 0}(x)

    The indicator function :math:`I_{x \\geq 0}(x)` is used to restrict the
    domain of :math:`F` such that :math:`F` is defined over whole space
    :math:`X`. The non-negativity thresholding :math:`pos` is used to define
    :math:`F` in the real numbers.

    Note that the functional is not well-defined without a prior g. Hence,
    if g is omitted this will be interpreted as if g is equal to the
    one-element.

    The convex conjugate :math:`F^*` of :math:`F` is

    .. math::
        F^*(p) = \\sum_i g_i (exp(p_i) - 1)

    where :math:`p` is the variable dual to :math:`x`.

    The proximal operator of the convex conjugate of :math:`F` is

    .. math::
        prox[\\sigma * (\\lambda*F)^*](x)_i = x_i - \\lambda
        W(\\frac{\\sigma}{\\lambda} g_i e^{x_i/\\lambda})

    where :math:`\\sigma` is the step size-like parameter, :math:`\\lambda`
    is the weighting in front of the function :math:`F`, and :math:`W` is
    the Lambert W function (see, for example, the
    `Wikipedia article <https://en.wikipedia.org/wiki/Lambert_W_function>`_).

    For real-valued input x, the Lambert :math:`W` function is defined only
    for :math:`x \\geq -1/e`, and it has two branches for values
    :math:`-1/e \\leq x < 0`. However, for intended use-cases, where
    :math:`\\lambda` and :math:`g` are positive, the argument of :math:`W`
    will always be positive.

    For further information about the functional, see the
    `Wikipedia article on Kullback Leibler divergence
    <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_
    or, for example, `this article
    <http://ieeexplore.ieee.org/document/1056144/?arnumber=1056144>`_.

    The KL cross entropy functional :math:`F`, described above, is related
    to another functional also known as KL divergence. That functional is
    often used as data discrepancy term in inverse problems, when data is
    corrupted with Poisson noise. It is obtained by changing place of the
    prior and the variable. See the See Also section.
    """
    lam = float(lam)

    # A user-supplied prior must live in the same space as the variable.
    if not (g is None or g in space):
        raise TypeError('{} is not an element of {}'.format(g, space))

    class ProximalCConjKLCrossEntropy(Operator):

        """Proximal operator of conjugate of cross entropy KL divergence."""

        def __init__(self, sigma):
            """Initialize a new instance.

            Parameters
            ----------
            sigma : positive float
                Step size-like parameter of the proximal operator.
            """
            self.sigma = float(sigma)
            super().__init__(domain=space, range=space, linear=False)

        def _call(self, x, out):
            """Apply the operator to ``x`` and store the result in ``out``."""
            scale = self.sigma / lam
            exp_term = np.exp(x / lam)

            if g is None:
                # A missing prior is taken as the one-element, so the factor
                # ``g`` simply drops out of the argument.
                w_arg = scale * exp_term
            else:
                w_arg = scale * g * exp_term

            # Different branches of lambertw is not an issue, see Notes
            out.lincomb(1, x, -lam, scipy.special.lambertw(w_arg))

    return ProximalCConjKLCrossEntropy


if __name__ == '__main__':
# pylint: disable=wrong-import-position
from odl.util.testutils import run_doctests
Expand Down
91 changes: 86 additions & 5 deletions test/largescale/solvers/advanced/proximal_operator_slow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
from odl.solvers.advanced.proximal_operators import (
proximal_l1, proximal_cconj_l1,
proximal_l2, proximal_cconj_l2,
proximal_l2_squared, proximal_cconj_l2_squared)
from odl.util.testutils import noise_element
proximal_l2_squared, proximal_cconj_l2_squared,
proximal_cconj_kl, proximal_cconj_kl_cross_entropy)
from odl.util.testutils import (noise_element, all_almost_equal)

from scipy.special import lambertw

pytestmark = odl.util.skip_if_no_largescale

Expand Down Expand Up @@ -71,10 +73,10 @@ def offset_function(function):
assert False
return offset_function


prox_params = ['l1 ', 'l1_dual',
'l2', 'l2_dual',
'l2^2', 'l2^2_dual']
'l2^2', 'l2^2_dual',
'kl_dual', 'kl_cross_ent_dual']
prox_ids = [' f = {}'.format(p.ljust(10)) for p in prox_params]


Expand Down Expand Up @@ -145,6 +147,42 @@ def l2_norm_squared_dual(x):

return prox(stepsize), l2_norm_squared_dual

elif name == 'kl_dual':
if g is not None:
g = np.abs(g)

def kl_divergence_dual(x):
if np.greater_equal(x, 1):
return np.Infinity
else:
one_element = x.space.one()
if g is None:
return stepsize * one_element.inner(
np.log(one_element - x))
else:
return stepsize * one_element.inner(
g * np.log(one_element - x))

prox = proximal_cconj_kl(space, g=g)

return prox(stepsize), kl_divergence_dual

elif name == 'kl_cross_ent_dual':
if g is not None:
g = np.abs(g)

def kl_divergence_cross_entropy_dual(x):
one_element = x.space.one()
if g is None:
return stepsize * one_element.inner(np.exp(x) - one_element)
else:
return stepsize * one_element.inner(
g * (np.exp(x) - one_element))

prox = proximal_cconj_kl_cross_entropy(space, g=g)

return prox(stepsize), kl_divergence_cross_entropy_dual

else:
assert False

Expand All @@ -163,7 +201,7 @@ def test_proximal_defintion(proximal_and_function):
x* = prox[f](x)
f(x*) + 1/2 ||x*-y||^2 < f(y) + 1/2 ||x-y||^2
f(x*) + 1/2 ||x-x*||^2 < f(y) + 1/2 ||x-y||^2
"""

proximal, function = proximal_and_function
Expand All @@ -183,5 +221,48 @@ def test_proximal_defintion(proximal_and_function):

assert f_prox_x <= f_y


def test_proximal_cconj_kl_cross_entropy_solving_opt_problem():
    """Test for proximal operator of conjugate of 2nd kind KL-divergence.

    The test solves the problem

        min_x lam*KL(x | g) + 1/2||x-a||^2_2,

    where g is the nonnegative prior, and a is any vector. Explicit solution
    to this problem is given by

        x = lam*W(g*e^(a/lam)/lam),

    where W is the Lambert W function.
    """
    # Image space
    space = odl.uniform_discr(0, 1, 10)

    # Data: the prior ``g`` and the vector ``a`` of the quadratic term
    prior = space.element(np.arange(10, 0, -1))
    shift = space.element(np.arange(4, 14, 1))

    # Weight in front of the KL functional
    lam_kl = 2.3

    # Both functionals act directly on x, hence two identity operators
    identity = odl.IdentityOperator(space)
    lin_ops = [identity, identity]

    # Proximals of the convex conjugates of the two terms of the objective
    prox_cc_kl = odl.solvers.proximal_cconj_kl_cross_entropy(
        space, lam=lam_kl, g=prior)
    prox_cc_l2_sq = odl.solvers.proximal_cconj_l2_squared(
        space, lam=1.0 / 2.0, g=shift)
    prox_cc_g = [prox_cc_kl, prox_cc_l2_sq]
    prox_f = odl.solvers.proximal_zero(space)

    # Starting point
    x = space.zero()

    odl.solvers.douglas_rachford_pd(x, prox_f, prox_cc_g, lin_ops,
                                    tau=2.1, sigma=[0.4, 0.4], niter=100)

    # Explicit solution: x = lam * W(g * exp(a / lam) / lam), where W is the
    # Lambert W function.
    x_verify = lam_kl * lambertw(
        (1.0 / lam_kl) * prior * np.exp(shift / lam_kl))
    assert all_almost_equal(x, x_verify, places=6)

if __name__ == '__main__':
    # Run this test module directly, verbosely and with the large-scale
    # tests enabled. The arguments are passed as a list: handing
    # ``pytest.main`` one command-line string is no longer supported by
    # current pytest releases and also breaks on paths containing spaces.
    pytest.main([str(__file__.replace('\\', '/')), '-v', '--largescale'])
Loading

0 comments on commit 10d2cdb

Please sign in to comment.