Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use KDE for beam screen images #200

Merged
merged 18 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Add Python 3.12 support (see #161) (@jank324)
- Implement space charge using Green's function in a `SpaceChargeKick` element (see #142) (@greglenerd, @RemiLehe, @ax3l, @cr-xu, @jank324)
- `Segment`s can now be imported from Bmad to devices other than `torch.device("cpu")` and dtypes other than `torch.float32` (see #196, #206) (@jank324)
- `Screen` will now use KDE for differentiable images (see #200) (@cr-xu, @roussel-ryan)

### 🐛 Bug fixes

Expand Down
84 changes: 64 additions & 20 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.distributions import MultivariateNormal

from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, kde_histogram_2d

from .element import Element

Expand All @@ -20,15 +20,20 @@ class Screen(Element):
Diagnostic screen in a particle accelerator.

:param resolution: Resolution of the camera sensor looking at the screen given as
Tensor `(width, height)`.
Tensor `(width, height)` in pixels.
:param pixel_size: Size of a pixel on the screen in meters given as a Tensor
`(width, height)`.
:param binning: Binning used by the camera.
:param misalignment: Misalignment of the screen in meters given as a Tensor
`(x, y)`.
:param kde_bandwith: Bandwidth used for the kernel density estimation in meters.
Controls the smoothness of the distribution.
:param is_active: If `True` the screen is active and will record the beam's
distribution. If `False` the screen is inactive and will not record the beam's
distribution.
:param method: Method used to generate the screen's reading. Can be either
"histogram" or "kde", defaults to "histogram".
KDE will be slower but allows backward differentiation.
:param name: Unique identifier of the element.
"""

Expand All @@ -38,7 +43,9 @@ def __init__(
pixel_size: Optional[Union[torch.Tensor, nn.Parameter]] = None,
binning: Optional[Union[torch.Tensor, nn.Parameter]] = None,
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
kde_bandwith: Optional[Union[torch.Tensor, nn.Parameter]] = None,
is_active: bool = False,
method: str = "histogram",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
Expand Down Expand Up @@ -66,6 +73,16 @@ def __init__(
if misalignment is not None
else torch.tensor([(0.0, 0.0)], **factory_kwargs)
)
assert method in [
"histogram",
"kde",
], f"Invalid method {method}. Must be either 'histogram' or 'kde'."
self.method = method
self.kde_bandwith = (
torch.as_tensor(kde_bandwith, **factory_kwargs)
if kde_bandwith is not None
else torch.clone(self.pixel_size[0])
)
self.length = torch.zeros(self.misalignment.shape[:-1], **factory_kwargs)
self.is_active = is_active

Expand Down Expand Up @@ -110,6 +127,13 @@ def pixel_bin_edges(self) -> tuple[torch.Tensor, torch.Tensor]:
),
)

@property
def pixel_bin_centers(self) -> tuple[torch.Tensor, torch.Tensor]:
return (
(self.pixel_bin_edges[0][1:] + self.pixel_bin_edges[0][:-1]) / 2,
(self.pixel_bin_edges[1][1:] + self.pixel_bin_edges[1][:-1]) / 2,
)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
device = self.misalignment.device
dtype = self.misalignment.dtype
Expand All @@ -121,11 +145,11 @@ def track(self, incoming: Beam) -> Beam:
copy_of_incoming = deepcopy(incoming)

if isinstance(incoming, ParameterBeam):
copy_of_incoming._mu[:, 0] -= self.misalignment[:, 0]
copy_of_incoming._mu[:, 2] -= self.misalignment[:, 1]
copy_of_incoming._mu[..., 0] -= self.misalignment[..., 0]
copy_of_incoming._mu[..., 2] -= self.misalignment[..., 1]
elif isinstance(incoming, ParticleBeam):
copy_of_incoming.particles[:, :, 0] -= self.misalignment[:, 0]
copy_of_incoming.particles[:, :, 1] -= self.misalignment[:, 1]
copy_of_incoming.particles[..., :, 0] -= self.misalignment[..., 0]
copy_of_incoming.particles[..., :, 1] -= self.misalignment[..., 1]

self.set_read_beam(copy_of_incoming)

Expand Down Expand Up @@ -188,22 +212,38 @@ def reading(self) -> torch.Tensor:
)
image = torch.flip(image, dims=[1])
elif isinstance(read_beam, ParticleBeam):
image = torch.zeros(
(
*self.misalignment.shape[:-1],
int(self.effective_resolution[1]),
int(self.effective_resolution[0]),

if self.method == "histogram":
image = torch.zeros(
(
*self.misalignment.shape[:-1],
int(self.effective_resolution[1]),
int(self.effective_resolution[0]),
)
)
)
for i, (xs_sample, ys_sample) in enumerate(zip(read_beam.xs, read_beam.ys)):
image_sample, _ = torch.histogramdd(
torch.stack((xs_sample, ys_sample)).T.cpu(),
bins=self.pixel_bin_edges,
for i, (xs_sample, ys_sample) in enumerate(
zip(read_beam.xs, read_beam.ys)
):
image_sample, _ = torch.histogramdd(
torch.stack((xs_sample, ys_sample)).T.cpu(),
bins=self.pixel_bin_edges,
)
image_sample = torch.flipud(image_sample.T)
image_sample = image_sample.cpu()

image[i] = image_sample
elif self.method == "kde":
image = kde_histogram_2d(
x1=read_beam.xs,
x2=read_beam.ys,
bins1=self.pixel_bin_centers[0],
bins2=self.pixel_bin_centers[1],
bandwidth=self.kde_bandwith,
)
image_sample = torch.flipud(image_sample.T)
image_sample = image_sample.cpu()

image[i] = image_sample
# Change the x, y positions
image = torch.transpose(image, -2, -1)
# Flip up and down, now row 0 corresponds to the top
image = torch.flip(image, dims=[-2])
else:
raise TypeError(f"Read beam is of invalid type {type(read_beam)}")

Expand Down Expand Up @@ -252,6 +292,8 @@ def defining_features(self) -> list[str]:
"pixel_size",
"binning",
"misalignment",
"method",
"kde_bandwith",
"is_active",
]

Expand All @@ -261,6 +303,8 @@ def __repr__(self) -> str:
+ f"pixel_size={repr(self.pixel_size)}, "
+ f"binning={repr(self.binning)}, "
+ f"misalignment={repr(self.misalignment)}, "
+ f"method={repr(self.method)}, "
+ f"kde_bandwith={repr(self.kde_bandwith)}, "
+ f"is_active={repr(self.is_active)}, "
+ f"name={repr(self.name)})"
)
176 changes: 176 additions & 0 deletions cheetah/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""Utility functions"""

import math
from typing import Optional, Tuple, Union

import torch
from torch import Tensor


class UniqueNameGenerator:
"""Generates a unique name given a prefix."""

Expand All @@ -9,3 +18,170 @@ def __call__(self):
name = f"{self._prefix}_{self._counter}"
self._counter += 1
return name


# Kernel density estimation (KDE) functionality
# Modified from kornia.enhance.histogram


def kde_marginal_pdf(
    values: torch.Tensor,
    bins: torch.Tensor,
    sigma: torch.Tensor,
    weights: Optional[Union[Tensor, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate a 1D Gaussian kernel density estimate at the given bin positions.

    Every sample contributes a Gaussian kernel of standard deviation `sigma`
    centred on its value. The kernels are evaluated at `bins` and summed over
    the sample axis. The result is NOT normalised to unit sum.

    Args:
        values: Sample positions, shape [BxNx1].
        bins: Positions at which the kernels are evaluated, shape [NUM_BINS].
        sigma: Zero-dimensional tensor, gaussian smoothing factor.
        weights: Optional sample weights. A tensor whose shape equals
            `values.shape[:-1]` (i.e. [BxN]) is applied per sample. A tensor of
            any other shape is multiplied in with ordinary broadcasting against
            [BxNxNUM_BINS]. `None` weights every sample with 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - torch.Tensor: kernel values summed over samples, shape
              [BxNUM_BINS].
            - torch.Tensor: per-sample kernel values, shape [BxNxNUM_BINS].

    Raises:
        TypeError: If `values`, `bins` or `sigma` is not a `torch.Tensor`.
        ValueError: If `bins` is not one-dimensional or `sigma` is not
            zero-dimensional.
    """

    if not isinstance(values, torch.Tensor):
        raise TypeError(f"Input values type is not a torch.Tensor. Got {type(values)}")

    if not isinstance(bins, torch.Tensor):
        raise TypeError(f"Input bins type is not a torch.Tensor. Got {type(bins)}")

    if not isinstance(sigma, torch.Tensor):
        raise TypeError(f"Input sigma type is not a torch.Tensor. Got {type(sigma)}")

    if bins.dim() != 1:
        raise ValueError(
            "Input bins must be a of the shape NUM_BINS" " Got {}".format(bins.shape)
        )

    if sigma.dim() != 0:
        raise ValueError(
            "Input sigma must be a of the shape 1" " Got {}".format(sigma.shape)
        )

    if weights is None or isinstance(weights, float):
        # NOTE(review): a float has always selected uniform unit weights here
        # (its value was never used as a scale). A plain scalar keeps that
        # behaviour without allocating a tensor on the wrong device/dtype.
        weights = 1.0
    elif isinstance(weights, torch.Tensor) and weights.shape == values.shape[:-1]:
        # One weight per sample: add the bin axis so the weight multiplies the
        # whole kernel of its sample instead of being broadcast along the bins.
        weights = weights.unsqueeze(-1)

    # [..., N, 1] - [NUM_BINS] broadcasts to [..., N, NUM_BINS].
    residuals = values - bins
    kernel_values = (
        weights
        * torch.exp(-0.5 * (residuals / sigma).pow(2))
        / torch.sqrt(2 * math.pi * sigma**2)
    )

    # Collapse the sample axis, keeping batch and bin axes.
    prob_mass = torch.sum(kernel_values, dim=-2)
    return prob_mass, kernel_values


def kde_joint_pdf_2d(
    kernel_values1: torch.Tensor, kernel_values2: torch.Tensor, epsilon: float = 1e-10
) -> torch.Tensor:
    """Combine two sets of per-sample kernel values into a normalised 2D density.

    The joint (unnormalised) density is the sum over samples of the outer
    product of the two kernel rows, computed as a batched matrix product. It is
    then divided by its total over both bin axes plus `epsilon`.

    Args:
        kernel_values1: shape [BxNxNUM_BINS], kernels of the first coordinate.
        kernel_values2: shape [BxNxNUM_BINS], kernels of the second coordinate.
        epsilon: scalar added to the normalisation for numerical stability.

    Returns:
        shape [BxNUM_BINSxNUM_BINS]; the first bin axis belongs to
        `kernel_values1`, the second to `kernel_values2`.

    Raises:
        TypeError: If either input is not a `torch.Tensor`.
    """

    named_inputs = (
        ("kernel_values1", kernel_values1),
        ("kernel_values2", kernel_values2),
    )
    for label, candidate in named_inputs:
        if not isinstance(candidate, torch.Tensor):
            raise TypeError(
                f"Input {label} type is not a torch.Tensor.Got {type(candidate)}"
            )

    # [B, NUM_BINS, N] @ [B, N, NUM_BINS] contracts the sample axis.
    unnormalised = kernel_values1.transpose(-2, -1) @ kernel_values2
    total = unnormalised.sum(dim=(-2, -1), keepdim=True) + epsilon

    return unnormalised / total


def kde_histogram_1d(
    x: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor, epsilon: float = 1e-10
) -> torch.Tensor:
    """Estimate the histogram using KDE of the input tensor.

    The calculation uses kernel density estimation which requires a bandwidth
    (smoothing) parameter. The summed kernel values are divided by their total
    over the bins (plus `epsilon`), so each returned histogram sums to
    approximately one.

    Args:
        x: Input tensor to compute the histogram with shape :math:`(B, D)`.
        bins: Bin positions at which the density is evaluated, shape
            :math:`(N_{bins})`.
        bandwidth: Zero-dimensional tensor, Gaussian smoothing factor.
        epsilon: A scalar added to the normalisation for numerical stability.

    Returns:
        Computed histogram of shape :math:`(B, N_{bins})`.

    Examples:
        >>> x = torch.rand(1, 10)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = kde_histogram_1d(x, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([1, 128])
    """

    # `epsilon` must not be forwarded positionally: the fourth parameter of
    # `kde_marginal_pdf` is `weights`, not a stabilisation constant.
    density, _ = kde_marginal_pdf(x.unsqueeze(-1), bins, bandwidth)

    # Normalise over the bin axis; epsilon guards against an all-zero density.
    normalization = torch.sum(density, dim=-1, keepdim=True) + epsilon

    return density / normalization


def kde_histogram_2d(
    x1: torch.Tensor,
    x2: torch.Tensor,
    bins1: torch.Tensor,
    bins2: torch.Tensor,
    bandwidth: torch.Tensor,
    weights=None,
    epsilon: Union[float, torch.Tensor] = 1e-10,
) -> torch.Tensor:
    """Estimate the 2d histogram of the input tensor.

    The calculation uses kernel density estimation which requires a bandwidth
    (smoothing) parameter. Kernel values are computed separately for each
    coordinate and combined into a normalised joint density.

    Args:
        x1: Input tensor to compute the histogram with shape :math:`(B, D1)`.
        x2: Input tensor to compute the histogram with shape :math:`(B, D2)`.
        bins1: Bin coordinates along the first coordinate,
            shape :math:`(N_{bins1})`.
        bins2: Bin coordinates along the second coordinate,
            shape :math:`(N_{bins2})`.
        bandwidth: Zero-dimensional tensor, Gaussian smoothing factor.
        weights: Optional sample weights, forwarded unchanged to
            `kde_marginal_pdf` for the first coordinate only (presumably one
            weight per sample; see that function for the accepted forms).
        epsilon: A scalar, for numerical stability. Default: 1e-10.

    Returns:
        Computed histogram of shape :math:`(B, N_{bins1}, N_{bins2})`.

    Examples:
        >>> x1 = torch.rand(2, 32)
        >>> x2 = torch.rand(2, 32)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = kde_histogram_2d(x1, x2, bins, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([2, 128, 128])
    """

    _, kernel_values1 = kde_marginal_pdf(x1.unsqueeze(-1), bins1, bandwidth, weights)
    # Apply the weights once only: the joint density multiplies the two kernel
    # sets per sample, so weighting both would weight each sample by w**2.
    _, kernel_values2 = kde_marginal_pdf(x2.unsqueeze(-1), bins2, bandwidth)

    return kde_joint_pdf_2d(kernel_values1, kernel_values2, epsilon=epsilon)
Loading