|
1 | 1 | #!/usr/bin/env python3
|
2 |
| -import pybind11_stubgen |
3 |
| -import re |
| 2 | +from typing import Dict |
4 | 3 |
|
| 4 | +from pybind11_stubgen import * |
| 5 | +from pybind11_stubgen.structs import * |
5 | 6 |
|
6 |
| -def add_union(alternative_type: str): |
7 |
| - def inner(match, _alternative_type=alternative_type): |
8 |
| - return "typing.Union[{}, {}]".format(match.group(0), _alternative_type) |
9 | 7 |
|
10 |
| - return inner |
| 8 | +class CustomWriter(Writer): |
| 9 | + def __init__(self, implicit_conversions: Dict[str, str], stub_ext: str = "pyi"): |
| 10 | + super().__init__(stub_ext=stub_ext) |
| 11 | + self.implicit_conversions = { |
| 12 | + QualifiedName.from_str(k): QualifiedName.from_str(v) for k, v in implicit_conversions.items() |
| 13 | + } |
11 | 14 |
|
| 15 | + def _patch_function(self, function: Function): |
| 16 | + for argument in function.args: |
| 17 | + if argument.annotation is not None and argument.annotation.name in self.implicit_conversions: |
| 18 | + converted_type = ResolvedType(self.implicit_conversions[argument.annotation.name]) |
| 19 | + argument.annotation = ResolvedType( |
| 20 | + QualifiedName.from_str("typing.Union"), [argument.annotation, converted_type]) |
12 | 21 |
|
13 |
| -if __name__ == '__main__': |
14 |
| - implicit_conversions = { |
15 |
| - "bool": "Condition", |
16 |
| - "float": "RelativeDynamicsFactor", |
17 |
| - "Affine": "RobotPose", |
18 |
| - } |
| 22 | + def write_module(self, module: Module, printer: Printer, to: Path, sub_dir: Optional[Path] = None): |
| 23 | + for cls in module.classes: |
| 24 | + for method in cls.methods: |
| 25 | + self._patch_function(method.function) |
| 26 | + for prop in cls.properties: |
| 27 | + if prop.setter is not None: |
| 28 | + self._patch_function(prop.setter) |
| 29 | + super().write_module(module, printer, to, sub_dir=sub_dir) |
19 | 30 |
|
20 |
| - pybind11_stubgen.StubsGenerator.GLOBAL_CLASSNAME_REPLACEMENTS.update({ |
21 |
| - re.compile("({})".format(orig_type)): add_union(alt_type) |
22 |
| - for alt_type, orig_type in implicit_conversions.items() |
23 |
| - }) |
24 | 31 |
|
25 |
| - pybind11_stubgen.main() |
| 32 | +IMPLICIT_CONVERSIONS = { |
| 33 | + "bool": "Condition", |
| 34 | + "float": "RelativeDynamicsFactor", |
| 35 | + "Affine": "RobotPose", |
| 36 | +} |
| 37 | + |
| 38 | +if __name__ == "__main__": |
| 39 | + logging.basicConfig( |
| 40 | + level=logging.INFO, |
| 41 | + format="%(name)s - [%(levelname)7s] %(message)s", |
| 42 | + ) |
| 43 | + args = arg_parser().parse_args() |
| 44 | + |
| 45 | + parser = stub_parser_from_args(args) |
| 46 | + |
| 47 | + printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is) |
| 48 | + |
| 49 | + out_dir, sub_dir = to_output_and_subdir( |
| 50 | + output_dir=args.output_dir, |
| 51 | + module_name=args.module_name, |
| 52 | + root_suffix=args.root_suffix, |
| 53 | + ) |
| 54 | + |
| 55 | + run( |
| 56 | + parser, |
| 57 | + printer, |
| 58 | + args.module_name, |
| 59 | + out_dir, |
| 60 | + sub_dir=sub_dir, |
| 61 | + dry_run=args.dry_run, |
| 62 | + writer=CustomWriter(IMPLICIT_CONVERSIONS, stub_ext=args.stub_extension), |
| 63 | + ) |
0 commit comments