Skip to content

Commit

Permalink
Use numpy for transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
reznakt committed Feb 26, 2025
1 parent b83b04e commit f069674
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 154 deletions.
50 changes: 31 additions & 19 deletions svglab/attrparse/d.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ClosePath(_PathCommandBase):
class LineTo(_HasEnd, _PhysicalPathCommand):
end: point.Point

def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
return type(self)(end=other @ self.end)


Expand All @@ -79,11 +79,11 @@ def __rmatmul__(
) -> Self: ...

@overload
def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self: ...
def __rmatmul__(self, other: transform.TransformFunction) -> Self: ...

@override
def __rmatmul__(
self, other: transform.SupportsToMatrix
self, other: transform.TransformFunction
) -> Self | LineTo:
if isinstance(other, transform.Translate | transform.Scale):
x, _ = other @ point.Point(self.x, 0)
Expand All @@ -107,11 +107,11 @@ def __rmatmul__(
) -> Self: ...

@overload
def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self: ...
def __rmatmul__(self, other: transform.TransformFunction) -> Self: ...

@override
def __rmatmul__(
self, other: transform.SupportsToMatrix
self, other: transform.TransformFunction
) -> Self | LineTo:
if isinstance(other, transform.Translate | transform.Scale):
_, y = other @ point.Point(0, self.y)
Expand All @@ -126,7 +126,7 @@ def __rmatmul__(
class SmoothQuadraticBezierTo(_HasEnd, _PhysicalPathCommand):
end: point.Point

def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
return type(self)(end=other @ self.end)


Expand All @@ -136,7 +136,7 @@ class SmoothCubicBezierTo(_HasEnd, _PhysicalPathCommand):
control2: point.Point
end: point.Point

def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
return type(self)(
control2=other @ self.control2, end=other @ self.end
)
Expand All @@ -147,7 +147,7 @@ def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
class MoveTo(_HasEnd, _PhysicalPathCommand):
end: point.Point

def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
return type(self)(end=other @ self.end)


Expand All @@ -157,7 +157,7 @@ class QuadraticBezierTo(_HasEnd, _PhysicalPathCommand):
control: point.Point
end: point.Point

def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
return type(self)(
control=other @ self.control, end=other @ self.end
)
Expand All @@ -170,7 +170,7 @@ class CubicBezierTo(_HasEnd, _PhysicalPathCommand):
control2: point.Point
end: point.Point

def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
return type(self)(
control1=other @ self.control1,
control2=other @ self.control2,
Expand All @@ -187,17 +187,29 @@ class ArcTo(_HasEnd, _PhysicalPathCommand):
sweep: bool
end: point.Point

def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
# right now, the radii are not transformed; this is incorrect
# TODO: figure out how to transform the radii
# (probably by converting to center parameterization, applying the
# transform, then converting back to endpoint parameterization)
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
radii = self.radii
angle = self.angle
end = self.end

match other:
case transform.Translate():
end = other @ end
case transform.Scale():
radii = other @ radii
end = other @ end
case transform.Rotate(a):
angle += a
case _:
msg = f"Unsupported transform: {type(other)}"
raise TypeError(msg)

return type(self)(
radii=self.radii,
angle=self.angle,
radii=radii,
angle=angle,
large=self.large,
sweep=self.sweep,
end=other @ self.end,
end=end,
)


Expand Down Expand Up @@ -899,7 +911,7 @@ def __len__(self) -> int:
return len(self.__commands)

@override
def __rmatmul__(self, other: transform.SupportsToMatrix) -> Self:
def __rmatmul__(self, other: transform.TransformFunction) -> Self:
return type(self)(
other @ command
for command in self
Expand Down
59 changes: 52 additions & 7 deletions svglab/attrparse/point.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections.abc import Iterator

import lark
import numpy as np
import numpy.typing as npt
import pydantic
from typing_extensions import (
Annotated,
Expand All @@ -13,16 +15,17 @@
override,
)

from svglab import mixins, protocols, serialize, utils
from svglab.attrparse import parse
from svglab import mixins, protocols, serialize, utils, utiltypes
from svglab.attrparse import parse, transform


@pydantic.dataclasses.dataclass(frozen=True)
class _Point(
SupportsComplex,
mixins.FloatMulDiv,
mixins.AddSub["_Point"],
transform.PointAddSubWithTranslateRMatmul,
protocols.PointLike,
protocols.SupportsNpArray,
protocols.CustomSerializable,
):
"""A point in a 2D plane.
Expand Down Expand Up @@ -100,17 +103,59 @@ def line_reflect(self, center: Self) -> Self:
"""
return center + (center - self)

@override
@override
def __array__(
self, dtype: npt.DTypeLike = None, *, copy: bool | None = None
) -> utiltypes.NpFloatArray:
del dtype, copy

return np.array([self.x, self.y, 1])

@classmethod
def from_array(cls, array: utiltypes.NpFloatArray, /) -> Self:
"""Create a `Point` instance from a NumPy array.
The array must be a 3-element vector, representing a cartesian point
in the real projective plane using homogeneous coordinates.
The last element of the array must be non-zero (i.e., the point must
not be at infinity).
Args:
array: The array to convert to a point.
Returns:
The point represented by the array.
Raises:
ValueError: If the array is not a 3-element vector or if the last
element of the array is zero.
"""
if array.shape != (3,):
raise ValueError("The array must be a 3-element vector")

x, y, z = array

try:
return cls(x / z, y / z)
except ZeroDivisionError as e:
# if z is zero, the point is at infinity
raise ValueError(
"The last element of the array cannot be zero"
) from e

def __iter__(self) -> Iterator[float]:
return iter((self.x, self.y))

@override
def __add__(self, other: Self) -> Self:
return type(self)(self.x + other.x, self.y + other.y)

@override
def __mul__(self, scalar: float) -> Self:
return type(self)(self.x * scalar, self.y * scalar)

def __rmatmul__(self, other: protocols.SupportsNpArray) -> Self:
return self.from_array(np.array(other) @ np.array(self))

@override
def __eq__(self, other: object) -> bool:
if not utils.basic_compare(other, self=self):
Expand Down
Loading

0 comments on commit f069674

Please sign in to comment.