
Commit 4869274

Refactor sum layers + consistency checks + bug fixes (#314)
* merged dense and mixing layer into sum layer over the whole lib
* refactor parameterization of sum layers
* re-run notebooks
* cleaning
* add MixingWeightInitializer as a way to symbolically initialize sum layers with arity > 1
* add sanity checks to the symbolic representation
* updated torch embedding layer impl
* minor fix when fold=False
* re-run notebooks
* add a few more symbolic checks
* fixed #319 NaNs with logic circuits
* moved from mixing weights initializer to mixing weights parameterizer
* minor fixes
* rerun some notebooks
* add docstrings about the mixing weight factory
* cleaning imports
* updated region-graphs-and-parametrisation.ipynb
1 parent 1ea1983 commit 4869274

33 files changed, +771 -748 lines

cirkit/backend/torch/compiler.py

+9-3
@@ -21,7 +21,7 @@
     match_optimization_patterns,
     optimize_graph,
 )
-from cirkit.backend.torch.initializers import stacked_initializer_
+from cirkit.backend.torch.initializers import foldwise_initializer_
 from cirkit.backend.torch.layers import TorchInputLayer, TorchLayer
 from cirkit.backend.torch.layers.input import TorchConstantLayer
 from cirkit.backend.torch.optimization.layers import (
@@ -370,7 +370,7 @@ def _fold_parameter_nodes_group(
         num_folds=len(group),
         requires_grad=group[0].requires_grad,
         initializer_=functools.partial(
-            stacked_initializer_, initializers=list(map(lambda p: p.initializer, group))
+            foldwise_initializer_, initializers=list(map(lambda p: p.initializer, group))
         ),
         dtype=group[0].dtype,
     )
@@ -548,6 +548,7 @@ def _match_layer_pattern(
     outcomings_fn: Callable[[TorchLayer], Sequence[TorchLayer]],
 ) -> LayerOptMatch | None:
     ppatterns = pattern.ppatterns()
+    cpatterns = pattern.cpatterns()
     pattern_entries = pattern.entries()
     num_entries = len(pattern_entries)
     matched_layers = []
@@ -566,7 +567,12 @@
         if len(out_nodes) > 1 and lid != 0:
             return None
 
-        # Second, attempt to match the patterns specified for its parameters
+        # Second, attempt to match the configuration patterns for the layer
+        for cname, cvalue in cpatterns[lid].items():
+            if layer.config[cname] != cvalue:
+                return None
+
+        # Third, attempt to match the patterns specified for its parameters
         lpmatches = {}
        for pname, ppattern in ppatterns[lid].items():
             pgraph = layer.params[pname]
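The added configuration check rejects a candidate layer as soon as one of the config entries required by the optimization pattern disagrees with the layer's actual config. A minimal standalone sketch of that idea, using plain dicts as stand-ins for the actual cirkit pattern and layer objects (names here are illustrative, not library API):

```python
# Hypothetical, simplified illustration of the configuration-pattern check:
# a layer only matches a pattern if every config entry the pattern constrains
# (e.g. a specific arity) equals the layer's actual config value.

def config_matches(layer_config: dict, cpattern: dict) -> bool:
    """Return True if the layer config satisfies all config constraints of the pattern."""
    return all(layer_config.get(name) == value for name, value in cpattern.items())

# Example: a pattern that only applies to sum layers of arity 1.
layer_config = {"num_input_units": 16, "num_output_units": 8, "arity": 1}
assert config_matches(layer_config, {"arity": 1})
assert not config_matches(layer_config, {"arity": 2})
```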

cirkit/backend/torch/initializers.py

+10-10
@@ -7,7 +7,14 @@
 InitializerFunc = Callable[[Tensor], Tensor]
 
 
-def copy_from_ndarray_(tensor: torch.Tensor, *, array: np.ndarray) -> Tensor:
+def foldwise_initializer_(t: Tensor, *, initializers: list[InitializerFunc | None]) -> Tensor:
+    for i, initializer_ in enumerate(initializers):
+        if initializer_ is not None:
+            initializer_(t[i])
+    return t
+
+
+def copy_from_ndarray_(tensor: Tensor, *, array: np.ndarray) -> Tensor:
     t = torch.from_numpy(array)
     default_float_dtype = torch.get_default_dtype()
     if t.is_floating_point():
@@ -21,7 +28,7 @@ def copy_from_ndarray_(tensor: torch.Tensor, *, array: np.ndarray) -> Tensor:
     return tensor.copy_(t)
 
 
-def dirichlet_(tensor: torch.Tensor, alpha: float | list[float], *, dim: int = -1) -> Tensor:
+def dirichlet_(tensor: Tensor, alpha: float | list[float], *, dim: int = -1) -> Tensor:
     shape = tensor.shape
     if len(shape) == 0:
         raise ValueError(
@@ -35,15 +42,8 @@ def dirichlet_(tensor: torch.Tensor, alpha: float | list[float], *, dim: int = -1) -> Tensor:
         raise ValueError(
             "The selected dim of the tensor and the size of concentration parameters do not match"
         )
-    concentration = torch.tensor(alpha)
+    concentration = Tensor(alpha)
     dirichlet = torch.distributions.Dirichlet(concentration)
     samples = dirichlet.sample(torch.Size([d for i, d in enumerate(shape) if i != dim]))
     tensor.copy_(torch.transpose(samples, dim, -1))
     return tensor
-
-
-def stacked_initializer_(t: Tensor, *, initializers: list[InitializerFunc | None]) -> Tensor:
-    for i, initializer_ in enumerate(initializers):
-        if initializer_ is not None:
-            initializer_(t[i])
-    return t
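The renamed `foldwise_initializer_` applies one initializer per fold of a stacked parameter tensor, skipping folds whose initializer is `None`. A minimal usage sketch, built only from the function body shown above; the per-fold initializers are standard `torch.nn.init` calls chosen for illustration:

```python
import functools

import torch

from cirkit.backend.torch.initializers import foldwise_initializer_

# A folded parameter tensor with 3 folds, each of shape (2, 4).
t = torch.empty(3, 2, 4)

# One initializer per fold; None leaves that fold untouched.
initializers = [
    torch.nn.init.zeros_,
    None,
    functools.partial(torch.nn.init.constant_, val=1.0),
]

foldwise_initializer_(t, initializers=initializers)
assert torch.all(t[0] == 0.0) and torch.all(t[2] == 1.0)
```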

cirkit/backend/torch/layers/__init__.py

-3
@@ -1,10 +1,7 @@
 from .base import TorchLayer as TorchLayer
-from .inner import TorchDenseLayer as TorchDenseLayer
 from .inner import TorchHadamardLayer as TorchHadamardLayer
 from .inner import TorchInnerLayer as TorchInnerLayer
 from .inner import TorchKroneckerLayer as TorchKroneckerLayer
-from .inner import TorchMixingLayer as TorchMixingLayer
-from .inner import TorchProductLayer as TorchProductLayer
 from .inner import TorchSumLayer as TorchSumLayer
 from .input import TorchCategoricalLayer as TorchCategoricalLayer
 from .input import TorchConstantValueLayer as TorchLogPartitionLayer
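For downstream code, the practical effect of the dropped exports is an import migration; a sketch only:

```python
# Before this commit (no longer exported):
# from cirkit.backend.torch.layers import TorchDenseLayer, TorchMixingLayer

# After this commit: both dense and mixing sums are instances of the unified layer.
from cirkit.backend.torch.layers import TorchSumLayer
```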

cirkit/backend/torch/layers/inner.py

+23-108
@@ -46,15 +46,7 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
         raise TypeError(f"Sampling not implemented for {type(self)}")
 
 
-class TorchProductLayer(TorchInnerLayer, ABC):
-    ...
-
-
-class TorchSumLayer(TorchInnerLayer, ABC):
-    ...
-
-
-class TorchHadamardLayer(TorchProductLayer):
+class TorchHadamardLayer(TorchInnerLayer):
     """The Hadamard product layer."""
 
     def __init__(
@@ -110,7 +102,7 @@ def sample(self, x: Tensor) -> tuple[Tensor, None]:
         return x, None
 
 
-class TorchKroneckerLayer(TorchProductLayer):
+class TorchKroneckerLayer(TorchInnerLayer):
     """The Kronecker product layer."""
 
     def __init__(
@@ -171,13 +163,14 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
         return torch.flatten(x, start_dim=2, end_dim=3), None
 
 
-class TorchDenseLayer(TorchSumLayer):
-    """The sum layer for dense sum within a layer."""
+class TorchSumLayer(TorchInnerLayer):
+    """The sum layer."""
 
     def __init__(
         self,
         num_input_units: int,
         num_output_units: int,
+        arity: int = 1,
         *,
         weight: TorchParameter,
         semiring: Semiring | None = None,
@@ -192,91 +185,7 @@ def __init__(
             num_folds (int): The number of channels. Defaults to 1.
         """
         assert weight.num_folds == num_folds
-        assert weight.shape == (num_output_units, num_input_units)
-        super().__init__(
-            num_input_units, num_output_units, arity=1, semiring=semiring, num_folds=num_folds
-        )
-        self.weight = weight
-
-    @property
-    def config(self) -> Mapping[str, Any]:
-        return {"num_input_units": self.num_input_units, "num_output_units": self.num_output_units}
-
-    @property
-    def params(self) -> Mapping[str, TorchParameter]:
-        return {"weight": self.weight}
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Run forward pass.
-
-        Args:
-            x (Tensor): The input to this layer, shape (F, H, B, Ki).
-
-        Returns:
-            Tensor: The output of this layer, shape (F, B, Ko).
-        """
-        x = x.squeeze(dim=1)  # shape (F, H=1, B, Ki) -> (F, B, Ki).
-        weight = self.weight()
-        return self.semiring.einsum(
-            "fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
-        )  # shape (F, B, Ko).
-
-    def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
-        weight = self.weight()
-        negative = torch.any(weight < 0.0)
-        if negative:
-            raise ValueError("Sampling only works with positive weights")
-        normalized = torch.allclose(torch.sum(weight, dim=-1), torch.ones(1, device=weight.device))
-        if not normalized:
-            raise ValueError("Sampling only works with a normalized parametrization")
-
-        # x: (F, H, C, K, num_samples, D)
-        c = x.shape[2]
-        d = x.shape[-1]
-        num_samples = x.shape[-2]
-
-        # mixing_distribution: (F, O, K)
-        mixing_distribution = torch.distributions.Categorical(probs=weight)
-
-        mixing_samples = mixing_distribution.sample((num_samples,))
-        mixing_samples = E.rearrange(mixing_samples, "n f o -> f o n")
-        mixing_indices = E.repeat(mixing_samples, "f o n -> f a c o n d", a=self.arity, c=c, d=d)
-
-        x = torch.gather(x, dim=-3, index=mixing_indices)
-        x = x[:, 0]
-        return x, mixing_samples
-
-
-class TorchMixingLayer(TorchSumLayer):
-    """The sum layer for mixture among layers.
-
-    It can also be used as a sparse sum within a layer when arity=1.
-    """
-
-    def __init__(
-        self,
-        num_input_units: int,
-        num_output_units: int,
-        arity: int = 2,
-        *,
-        weight: TorchParameter,
-        semiring: Semiring | None = None,
-        num_folds: int = 1,
-    ) -> None:
-        """Init class.
-
-        Args:
-            num_input_units (int): The number of input units.
-            num_output_units (int): The number of output units, must be the same as input.
-            arity (int, optional): The arity of the layer. Defaults to 2.
-            weight (TorchParameter): The reparameterization for layer parameters.
-            num_folds (int): The number of channels. Defaults to 1.
-        """
-        assert (
-            num_output_units == num_input_units
-        ), "The number of input and output units must be the same for MixingLayer."
-        assert weight.num_folds == num_folds
-        assert weight.shape == (num_output_units, arity)
+        assert weight.shape == (num_output_units, arity * num_input_units)
         super().__init__(
             num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
         )
@@ -303,11 +212,13 @@ def forward(self, x: Tensor) -> Tensor:
         Returns:
             Tensor: The output of this layer, shape (F, B, Ko).
         """
-        # shape (F, H, B, K) -> (F, B, K).
+        # x: (F, H, B, Ki) -> (F, B, H * Ki)
+        # weight: (F, Ko, H * Ki)
+        x = x.permute(0, 2, 1, 3).flatten(start_dim=2)
         weight = self.weight()
         return self.semiring.einsum(
-            "fhbk,fkh->fbk", inputs=(x,), operands=(weight,), dim=1, keepdim=False
-        )
+            "fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
+        )  # shape (F, B, Ko).
 
     def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
         weight = self.weight()
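The shape bookkeeping behind the unified forward pass can be checked in isolation with a plain `torch.einsum`. This sketch ignores cirkit's semiring machinery and uses ordinary sum-product semantics; all sizes are made up for illustration:

```python
import torch

F_, H, B, Ki, Ko = 2, 3, 5, 4, 6      # folds, arity, batch, input units, output units
x = torch.randn(F_, H, B, Ki)         # layer input, shape (F, H, B, Ki)
weight = torch.rand(F_, Ko, H * Ki)   # folded weights, shape (F, Ko, H * Ki)

# (F, H, B, Ki) -> (F, B, H * Ki): fold the arity dimension into the unit dimension,
# so a dense sum (H=1) and a mixing sum (H>1) become the same contraction.
x_flat = x.permute(0, 2, 1, 3).flatten(start_dim=2)

# The same contraction the layer performs, here over the ordinary sum-product semiring.
y = torch.einsum("fbi,foi->fbo", x_flat, weight)
assert y.shape == (F_, B, Ko)
```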
@@ -318,18 +229,22 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
         if not normalized:
             raise ValueError("Sampling only works with a normalized parametrization")
 
-        # x: (F, H, C, K, num_samples, D)
-        c = x.shape[2]
-        k = x.shape[-3]
-        d = x.shape[-1]
-        num_samples = x.shape[-2]
+        # x: (F, H, C, Ki, num_samples, D) -> (F, C, H * Ki, num_samples, D)
+        x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)
+        c = x.shape[1]
+        num_samples = x.shape[3]
+        d = x.shape[4]
 
-        # mixing_distribution: (F, O, K)
+        # mixing_distribution: (F, Ko, H * Ki)
         mixing_distribution = torch.distributions.Categorical(probs=weight)
 
+        # mixing_samples: (num_samples, F, Ko) -> (F, Ko, num_samples)
         mixing_samples = mixing_distribution.sample((num_samples,))
         mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")
-        mixing_indices = E.repeat(mixing_samples, "f k n -> f 1 c k n d", c=c, k=k, d=d)
 
-        x = torch.gather(x, 1, mixing_indices)[:, 0]
+        # mixing_indices: (F, C, Ko, num_samples, D)
+        mixing_indices = E.repeat(mixing_samples, "f k n -> f c k n d", c=c, d=d)
+
+        # x: (F, C, Ko, num_samples, D)
+        x = torch.gather(x, dim=2, index=mixing_indices)
         return x, mixing_samples
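A shape-only sketch of the updated sampling path, again with plain torch and einops calls and made-up sizes; it mirrors the indexing logic shown in the diff, not the full layer:

```python
import einops as E
import torch

# folds, arity, channels, input/output units, samples, variables
F_, H, C, Ki, Ko, N, D = 2, 3, 1, 4, 6, 7, 5
weight = torch.softmax(torch.randn(F_, Ko, H * Ki), dim=-1)  # normalized mixing weights
x = torch.randint(0, 10, (F_, H, C, Ki, N, D))               # incoming samples

# (F, H, C, Ki, N, D) -> (F, C, H * Ki, N, D)
x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)

# One categorical index per (fold, output unit, sample): (N, F, Ko) -> (F, Ko, N).
mixing_samples = torch.distributions.Categorical(probs=weight).sample((N,))
mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")

# Broadcast the indices over channels and variables, then select along the unit axis.
mixing_indices = E.repeat(mixing_samples, "f k n -> f c k n d", c=C, d=D)
out = torch.gather(x, dim=2, index=mixing_indices)
assert out.shape == (F_, C, Ko, N, D)
```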

cirkit/backend/torch/layers/input.py

+13-6
@@ -166,12 +166,19 @@ def params(self) -> Mapping[str, TorchParameter]:
     def forward(self, x: Tensor) -> Tensor:
         if x.is_floating_point():
             x = x.long()  # The input to Embedding should be discrete
-        x = F.one_hot(x, self.num_states)  # (F, C, B, 1 num_states)
-        x = x.squeeze(dim=3)  # (F, C, B, num_states)
+        x = x.squeeze(dim=3)  # (F, C, B)
         weight = self.weight()
-        x = torch.einsum("fcbi,fkci->fbkc", x.to(weight.dtype), weight)
-        x = self.semiring.map_from(x, SumProductSemiring)
-        return self.semiring.prod(x, dim=-1)  # (F, B, K)
+        if self.num_channels == 1:
+            idx_fold = torch.arange(self.num_folds, device=weight.device)
+            x = weight[:, :, 0][idx_fold[:, None], :, x[:, 0]]
+            x = self.semiring.map_from(x, SumProductSemiring)
+        else:
+            idx_fold = torch.arange(self.num_folds, device=weight.device)[:, None, None]
+            idx_channel = torch.arange(self.num_channels, device=weight.device)[None, :, None]
+            x = weight[idx_fold, :, idx_channel, x]
+            x = self.semiring.map_from(x, SumProductSemiring)
+            x = self.semiring.prod(x, dim=1)
+        return x  # (F, B, K)
 
 
 class TorchExpFamilyLayer(TorchInputLayer, ABC):
@@ -332,7 +339,7 @@ def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
             x = logits[:, :, 0][idx_fold[:, None], :, x[:, 0]]
         else:
             idx_fold = torch.arange(self.num_folds, device=logits.device)[:, None, None]
-            idx_channel = torch.arange(self.num_channels)[None, :, None]
+            idx_channel = torch.arange(self.num_channels, device=logits.device)[None, :, None]
             x = torch.sum(logits[idx_fold, :, idx_channel, x], dim=1)
         return x
 
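The removed one-hot/einsum formulation and the new advanced-indexing lookup compute the same values; the lookup simply avoids materializing the one-hot tensor. A small standalone check under assumed shapes (embedding table of shape (F, K, C, num_states), integer inputs of shape (F, C, B, 1)), covering the single-channel branch only:

```python
import torch
import torch.nn.functional as F

F_, C, B, K, S = 2, 1, 5, 3, 7             # folds, channels, batch, units, num_states
weight = torch.randn(F_, K, C, S)          # embedding table, shape (F, K, C, num_states)
x = torch.randint(0, S, (F_, C, B, 1))     # discrete inputs, shape (F, C, B, 1)

# Old formulation: one-hot encode, then contract with an einsum.
x_oh = F.one_hot(x, S).squeeze(dim=3).float()                        # (F, C, B, num_states)
old = torch.einsum("fcbi,fkci->fbkc", x_oh, weight).squeeze(dim=-1)  # (F, B, K)

# New formulation (single-channel branch): a direct advanced-indexing lookup.
xs = x.squeeze(dim=3)                      # (F, C, B)
idx_fold = torch.arange(F_)
new = weight[:, :, 0][idx_fold[:, None], :, xs[:, 0]]                # (F, B, K)

assert torch.allclose(old, new)
```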
cirkit/backend/torch/layers/optimized.py

+4-9
@@ -1,21 +1,16 @@
-from abc import ABC
 from collections.abc import Mapping
 from typing import Any
 
 import einops as E
 import torch
 from torch import Tensor
 
-from cirkit.backend.torch.layers import TorchInnerLayer, TorchSumLayer
+from cirkit.backend.torch.layers import TorchInnerLayer
 from cirkit.backend.torch.parameters.parameter import TorchParameter
 from cirkit.backend.torch.semiring import Semiring
 
 
-class TorchSumProductLayer(TorchInnerLayer, ABC):
-    ...
-
-
-class TorchTuckerLayer(TorchSumProductLayer):
+class TorchTuckerLayer(TorchInnerLayer):
     """The Tucker (2) layer, which is a fused dense-kronecker.
 
     A ternary einsum is used to fuse the sum and product.
@@ -81,7 +76,7 @@ def forward(self, x: Tensor) -> Tensor:
         )
 
 
-class TorchCPTLayer(TorchSumProductLayer):
+class TorchCPTLayer(TorchInnerLayer):
     """The Candecomp Parafac (collapsed) layer, which is a fused dense-hadamard.
 
     The fusion actually does not gain anything, and is just a plain connection. We don't because \
@@ -173,7 +168,7 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
         return x, mixing_samples
 
 
-class TorchTensorDotLayer(TorchSumLayer):
+class TorchTensorDotLayer(TorchInnerLayer):
     """The sum layer for dense sum within a layer."""
 
     def __init__(
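Taken together with the changes in inner.py, the refactor flattens the inner-layer class hierarchy. A rough stub-only sketch of the resulting relationships, as visible in the diffs above (not actual library code):

```python
# The intermediate TorchProductLayer, abstract TorchSumLayer and
# TorchSumProductLayer bases are removed; every concrete inner layer now
# derives directly from TorchInnerLayer, and TorchSumLayer is the single
# concrete sum layer covering both dense (arity=1) and mixing (arity>1) sums.

class TorchInnerLayer: ...                       # base class for all inner layers

class TorchHadamardLayer(TorchInnerLayer): ...   # Hadamard product
class TorchKroneckerLayer(TorchInnerLayer): ...  # Kronecker product
class TorchSumLayer(TorchInnerLayer): ...        # unified dense/mixing sum
class TorchTuckerLayer(TorchInnerLayer): ...     # fused dense-kronecker
class TorchCPTLayer(TorchInnerLayer): ...        # fused dense-hadamard
class TorchTensorDotLayer(TorchInnerLayer): ...  # optimized dense sum
```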
