From 4ee330c2f21b020c29e791c814011875f4309f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=AD=20Zamora=20Casals?= Date: Wed, 18 Dec 2024 12:08:57 +0100 Subject: [PATCH] Fix types so that pyright detects decorated dataclasses properly --- jaxtyping/_decorator.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index ef876a5..a6372a4 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -33,7 +33,9 @@ get_type_hints, NoReturn, overload, + Type, TypeVar, + Union, ) @@ -73,17 +75,23 @@ def _apply_typechecker(typechecker, fn): return typechecker(fn) +_T = TypeVar("_T") + + @overload def jaxtyped( *, typechecker=_sentinel, -) -> Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]]: ... +) -> Union[ + Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]], + Callable[[Type[_T]], Type[_T]], +]: ... @overload def jaxtyped( - fn: Callable[_Params, _Return], *, typechecker=_sentinel -) -> Callable[_Params, _Return]: ... + fn: Union[Callable[_Params, _Return], Type[_T]], *, typechecker=_sentinel +) -> Union[Callable[_Params, _Return], Type[_T]]: ... def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):