1
1
#!/usr/bin/env python3
2
2
import logging
3
3
from pathlib import Path
4
- from typing import Dict , Optional
4
+ from typing import Dict , Optional , Sequence
5
+ from collections import defaultdict
5
6
6
7
from pybind11_stubgen import Writer , QualifiedName , Printer , arg_parser , stub_parser_from_args , to_output_and_subdir , \
7
8
run
8
9
from pybind11_stubgen .structs import Function , ResolvedType , Module
9
10
10
11
11
12
class CustomWriter (Writer ):
12
- def __init__ (self , implicit_conversions : Dict [str , str ], stub_ext : str = "pyi" ):
13
+ def __init__ (self , alternative_types : Dict [str , Sequence [ str , ...] ], stub_ext : str = "pyi" ):
13
14
super ().__init__ (stub_ext = stub_ext )
14
- self .implicit_conversions = {
15
- QualifiedName .from_str (k ): QualifiedName .from_str (v ) for k , v in implicit_conversions .items ()
15
+ self .alternative_types = {
16
+ QualifiedName .from_str (k ): tuple ( QualifiedName .from_str (e ) for e in v ) for k , v in alternative_types .items ()
16
17
}
17
18
18
19
def _patch_function (self , function : Function ):
19
20
for argument in function .args :
20
- if argument .annotation is not None and argument .annotation .name in self .implicit_conversions :
21
- converted_type = ResolvedType (self .implicit_conversions [argument .annotation .name ])
21
+ if argument .annotation is not None and argument .annotation .name in self .alternative_types :
22
+ converted_types = [ ResolvedType (e ) for e in self .alternative_types [argument .annotation .name ]]
22
23
argument .annotation = ResolvedType (
23
- QualifiedName .from_str ("typing.Union" ), [argument .annotation , converted_type ] )
24
+ QualifiedName .from_str ("typing.Union" ), [argument .annotation ] + converted_types )
24
25
25
26
def write_module (self , module : Module , printer : Printer , to : Path , sub_dir : Optional [Path ] = None ):
26
27
for cls in module .classes :
@@ -32,11 +33,20 @@ def write_module(self, module: Module, printer: Printer, to: Path, sub_dir: Opti
32
33
super ().write_module (module , printer , to , sub_dir = sub_dir )
33
34
34
35
35
- IMPLICIT_CONVERSIONS = {
36
- "bool" : "Condition" ,
37
- "float" : "RelativeDynamicsFactor" ,
38
- "Affine" : "RobotPose" ,
39
- }
36
+ IMPLICIT_CONVERSIONS = [
37
+ ("bool" , "Condition" ),
38
+ ("float" , "RelativeDynamicsFactor" ),
39
+ ("Affine" , "RobotPose" ),
40
+ ("Twist" , "RobotVelocity" ),
41
+ ("RobotPose" , "CartesianState" ),
42
+ ("Affine" , "CartesianState" ),
43
+ ("list[float]" , "JointState" ),
44
+ ("np.ndarray" , "JointState" )
45
+ ]
46
+
47
+ alternatives = defaultdict (list )
48
+ for from_type , to_type in IMPLICIT_CONVERSIONS :
49
+ alternatives [to_type ].append (from_type )
40
50
41
51
if __name__ == "__main__" :
42
52
logging .basicConfig (
0 commit comments