Skip to content

Commit fac29cc

Browse files
Fixed implicit conversion handling in stubgen
1 parent 750153d commit fac29cc

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

custom_stubgen.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
#!/usr/bin/env python3
22
import logging
33
from pathlib import Path
4-
from typing import Dict, Optional
4+
from typing import Dict, Optional, Sequence
5+
from collections import defaultdict
56

67
from pybind11_stubgen import Writer, QualifiedName, Printer, arg_parser, stub_parser_from_args, to_output_and_subdir, \
78
run
89
from pybind11_stubgen.structs import Function, ResolvedType, Module
910

1011

1112
class CustomWriter(Writer):
12-
def __init__(self, implicit_conversions: Dict[str, str], stub_ext: str = "pyi"):
13+
def __init__(self, alternative_types: Dict[str, Sequence[str, ...]], stub_ext: str = "pyi"):
1314
super().__init__(stub_ext=stub_ext)
14-
self.implicit_conversions = {
15-
QualifiedName.from_str(k): QualifiedName.from_str(v) for k, v in implicit_conversions.items()
15+
self.alternative_types = {
16+
QualifiedName.from_str(k): tuple(QualifiedName.from_str(e) for e in v) for k, v in alternative_types.items()
1617
}
1718

1819
def _patch_function(self, function: Function):
1920
for argument in function.args:
20-
if argument.annotation is not None and argument.annotation.name in self.implicit_conversions:
21-
converted_type = ResolvedType(self.implicit_conversions[argument.annotation.name])
21+
if argument.annotation is not None and argument.annotation.name in self.alternative_types:
22+
converted_types = [ResolvedType(e) for e in self.alternative_types[argument.annotation.name]]
2223
argument.annotation = ResolvedType(
23-
QualifiedName.from_str("typing.Union"), [argument.annotation, converted_type])
24+
QualifiedName.from_str("typing.Union"), [argument.annotation] + converted_types)
2425

2526
def write_module(self, module: Module, printer: Printer, to: Path, sub_dir: Optional[Path] = None):
2627
for cls in module.classes:
@@ -32,11 +33,20 @@ def write_module(self, module: Module, printer: Printer, to: Path, sub_dir: Opti
3233
super().write_module(module, printer, to, sub_dir=sub_dir)
3334

3435

35-
IMPLICIT_CONVERSIONS = {
36-
"bool": "Condition",
37-
"float": "RelativeDynamicsFactor",
38-
"Affine": "RobotPose",
39-
}
36+
IMPLICIT_CONVERSIONS = [
37+
("bool", "Condition"),
38+
("float", "RelativeDynamicsFactor"),
39+
("Affine", "RobotPose"),
40+
("Twist", "RobotVelocity"),
41+
("RobotPose", "CartesianState"),
42+
("Affine", "CartesianState"),
43+
("list[float]", "JointState"),
44+
("np.ndarray", "JointState")
45+
]
46+
47+
alternatives = defaultdict(list)
48+
for from_type, to_type in IMPLICIT_CONVERSIONS:
49+
alternatives[to_type].append(from_type)
4050

4151
if __name__ == "__main__":
4252
logging.basicConfig(

0 commit comments

Comments
 (0)