Skip to content

Commit

Permalink
Add basic support for statically typing allowed children
Browse files Browse the repository at this point in the history
  • Loading branch information
reznakt committed Nov 21, 2024
1 parent b86d30a commit 79f149a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 23 deletions.
18 changes: 10 additions & 8 deletions svglab/__main__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# ruff: noqa: T201

from .elements import CData, Comment, G, Rect, Text
from .elements import CData, Comment, G, Rect, Svg, Text
from .io import parse_svg


def main() -> None:
soup = parse_svg("<foo></foo>")
print(soup.prettify())

group = (
group = Svg(
G()
.add_child(Rect())
.add_child(Comment("This is an example comment"))
Expand All @@ -18,12 +18,14 @@ def main() -> None:
)
print(group)

group2 = G(
Rect(),
Comment("This is an example comment"),
CData("foo { background-color: red; }"),
Text("baz"),
G(Rect(), Rect()),
group2 = Svg(
G(
Rect(),
Comment("This is an example comment"),
CData("foo { background-color: red; }"),
Text("baz"),
G(Rect(), Rect()),
)
)
print(group2)

Expand Down
43 changes: 28 additions & 15 deletions svglab/elements.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from collections.abc import Hashable, Iterable
from contextlib import suppress
from typing import Final, Self, cast, final
from typing import Final, Self, Union, cast, final
from warnings import warn

import bs4
Expand All @@ -15,12 +15,17 @@
"Tag",
"Rect",
"G",
"Svg",
]

type AnyElement = Element[bs4.PageElement]
type Backend = bs4.PageElement
type TextBackend = bs4.Comment | bs4.CData | bs4.NavigableString

type AnyElement = Element[Backend]
type AnyTextElement = TextElement[TextBackend]

def backend_to_element(backend: bs4.PageElement) -> AnyElement | None:

def backend_to_element(backend: Backend) -> AnyElement | None:
match backend:
case bs4.Tag():
if backend.is_empty_element:
Expand All @@ -37,7 +42,7 @@ def backend_to_element(backend: bs4.PageElement) -> AnyElement | None:
return None


class Element[T: bs4.PageElement](Repr, Hashable, metaclass=ABCMeta):
class Element[T: Backend](Repr, Hashable, metaclass=ABCMeta):
def __init__(self, *, _backend: T | None = None) -> None:
self._backend = _backend if _backend is not None else self._default_backend

Expand Down Expand Up @@ -76,9 +81,7 @@ def parent(self) -> AnyElement | None:
return backend_to_element(parent) if parent is not None else None


class TextElement[T: bs4.Comment | bs4.CData | bs4.NavigableString](
Element[T], metaclass=ABCMeta
):
class TextElement[T: TextBackend](Element[T], metaclass=ABCMeta):
def __init__(
self,
content: str | None = None,
Expand Down Expand Up @@ -175,31 +178,34 @@ def _default_backend(self) -> bs4.Tag:
return bs4.Tag(name=self.name, can_be_empty_element=not self.paired)


class PairedTag(Tag, metaclass=ABCMeta):
class PairedTag[T: AnyElement](Tag, metaclass=ABCMeta):
paired = True

def __init__(self, *children: AnyElement, _backend: bs4.Tag | None = None) -> None:
def __init__(self, *children: T, _backend: bs4.Tag | None = None) -> None:
super().__init__(_backend=_backend)

for child in children:
self.add_child(child)

@property
def children(self) -> SizedIterable[AnyElement]:
def children(self) -> SizedIterable[T]:
return SizedIterable(self.__children())

def __children(self) -> Iterable[AnyElement]:
def __children(self) -> Iterable[T]:
for child in self._backend.children:
element = backend_to_element(child)

if element is not None:
yield element
# there is no way to statically ensure that the
# element is of the correct type, so we have to cast
# TODO: make sure this is correctly handled at runtime
yield cast(T, element)

def add_child(self, child: AnyElement) -> Self:
def add_child(self, child: T) -> Self:
self._backend.append(child.backend)
return self

def add_children(self, children: Iterable[AnyElement]) -> Self:
def add_children(self, children: Iterable[T]) -> Self:
# use extend() because it's faster than multiple calls to add_child()
self._backend.extend(child.backend for child in children)
return self
Expand All @@ -214,6 +220,13 @@ class Rect(UnpairedTag):
name: Final = "rect"


# use Union because the new syntax doesn't seem to work well
# with recursive types
@final
class G(PairedTag):
class G(PairedTag[Union[AnyTextElement, "G", Rect]]):
name: Final = "g"


@final
class Svg(PairedTag[AnyElement]):
name: Final = "svg"

0 comments on commit 79f149a

Please sign in to comment.