Skip to content

Commit

Permalink
Refactor endpoint model creation
Browse files Browse the repository at this point in the history
  • Loading branch information
yugokato committed Mar 7, 2025
1 parent 13c6619 commit 52c89ae
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
import inspect
import json
import re
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import MISSING, Field, field, make_dataclass
from dataclasses import MISSING, field, make_dataclass
from typing import TYPE_CHECKING, Any, cast

from common_libs.logging import get_logger

from openapi_test_client.libraries.api.api_functions.utils import param_model as param_model_util
from openapi_test_client.libraries.api.api_functions.utils import param_type as param_type_util
from openapi_test_client.libraries.api.types import EndpointModel, File, ParamDef, Unset
from openapi_test_client.libraries.api.types import (
DataclassModelField,
EndpointModel,
File,
ParamDef,
Unset,
)

if TYPE_CHECKING:
from openapi_test_client.libraries.api import EndpointFunc
Expand All @@ -27,8 +34,8 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
:param api_spec: Create a model from the OpenAPI spec. Otherwise the model be created from the existing endpoint
function signatures
"""
path_param_fields = []
body_or_query_param_fields = []
path_param_fields: list[DataclassModelField] = []
body_or_query_param_fields: list[DataclassModelField] = []
model_name = f"{type(endpoint_func).__name__.replace('EndpointFunc', EndpointModel.__name__)}"
content_type = None
if api_spec:
Expand All @@ -48,11 +55,10 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
continue
elif param_obj.default == inspect.Parameter.empty:
# Positional arguments (path parameters)
path_param_fields.append((name, param_obj.annotation))
path_param_fields.append(DataclassModelField(name, param_obj.annotation))
else:
# keyword arguments (body/query parameters)
param_field = (name, param_obj.annotation, field(default=Unset))
body_or_query_param_fields.append(param_field)
_add_body_or_query_param_field(body_or_query_param_fields, name, param_obj.annotation)

if hasattr(endpoint_func, "endpoint"):
method = endpoint_func.endpoint.method
Expand All @@ -64,13 +70,13 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
# Some OpenAPI specs don't properly document path parameters at all, or path parameters could be documented
# as incorrect "in" like "query". We fix this by adding the missing path parameters, and remove them from
# body/query params if any
path_param_fields = [(x, str) for x in expected_path_params]
path_param_fields = [DataclassModelField(x, str) for x in expected_path_params]
body_or_query_param_fields = [x for x in body_or_query_param_fields if x[0] not in expected_path_params]

# Address the case where a path param name conflicts with body/query param name
for i, (field_name, field_type) in enumerate(path_param_fields):
for i, (field_name, field_type, _) in enumerate(path_param_fields):
if field_name in [x[0] for x in body_or_query_param_fields]:
path_param_fields[i] = (f"{field_name}_", field_type)
path_param_fields[i] = DataclassModelField(f"{field_name}_", field_type)

# Some OpenAPI specs define a parameter name using characters we can't use as a python variable name.
# We will use the cleaned name as the model field and annotate it as `Annotated[field_type, Alias(<original_val>)]`
Expand All @@ -83,7 +89,7 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
type[EndpointModel],
make_dataclass(
model_name,
fields,
fields, # type: ignore
bases=(EndpointModel,),
namespace={"content_type": content_type, "endpoint_func": endpoint_func},
kw_only=True,
Expand Down Expand Up @@ -130,8 +136,8 @@ def generate_func_signature_in_str(model: type[EndpointModel]) -> str:
def _parse_parameter_objects(
method: str,
parameter_objects: list[dict[str, Any]],
path_param_fields: list[tuple[str, Any]],
body_or_query_param_fields: list[tuple[str, Any, Field]],
path_param_fields: list[DataclassModelField],
body_or_query_param_fields: list[DataclassModelField],
):
"""Parse parameter objects
Expand All @@ -148,12 +154,13 @@ def _parse_parameter_objects(
param_type_annotation = param_type_util.resolve_type_annotation(
param_name, param_def, _is_required=is_required
)

if param_location in ["header", "cookies"]:
# We currently don't support these
continue
elif param_location == "path":
path_param_fields.append((param_name, param_type_annotation))
if param_name not in [x[0] for x in path_param_fields]:
# Handle duplicates. Some API specs incorrectly document duplicated parameters
path_param_fields.append(DataclassModelField(param_name, param_type_annotation))
elif param_location == "query":
if method.upper() != "GET":
# Annotate query params for non GET endpoints
Expand All @@ -180,19 +187,14 @@ def _parse_parameter_objects(
method, parameter_objects, path_param_fields, body_or_query_param_fields
)
else:
if param_name not in [x[0] for x in body_or_query_param_fields]:
body_or_query_param_fields.append(
(
param_name,
param_type_annotation,
field(default=Unset, metadata=param_obj),
)
)
else:
if param_name not in [x[0] for x in body_or_query_param_fields]:
body_or_query_param_fields.append(
(param_name, param_type_annotation, field(default=Unset, metadata=param_obj))
_add_body_or_query_param_field(
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
)

else:
_add_body_or_query_param_field(
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
)
else:
raise NotImplementedError(f"Unsupported param 'in': {param_location}")
except Exception:
Expand All @@ -205,7 +207,7 @@ def _parse_parameter_objects(


def _parse_request_body_object(
request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Field]]
request_body_obj: dict[str, Any], body_or_query_param_fields: list[DataclassModelField]
) -> str | None:
"""Parse request body object
Expand Down Expand Up @@ -250,7 +252,9 @@ def parse_schema_obj(obj: dict[str, Any]):
param_type = File
if not param_def.is_required:
param_type = param_type | None
body_or_query_param_fields.append((param_name, param_type, field(default=Unset)))
_add_body_or_query_param_field(
body_or_query_param_fields, param_name, param_type, param_obj=param_obj
)
else:
existing_param_names = [x[0] for x in body_or_query_param_fields]
if param_name in existing_param_names:
Expand All @@ -259,16 +263,17 @@ def parse_schema_obj(obj: dict[str, Any]):
for _, t, m in duplicated_param_fields:
param_type_annotations.append(t)
param_type_annotation = param_type_util.generate_union_type(param_type_annotations)
merged_param_field = (
merged_param_field = DataclassModelField(
param_name,
param_type_annotation,
field(default=Unset, metadata=param_obj),
default=field(default=Unset, metadata=param_obj),
)
body_or_query_param_fields[existing_param_names.index(param_name)] = merged_param_field
else:
param_type_annotation = param_type_util.resolve_type_annotation(param_name, param_def)
param_field = (param_name, param_type_annotation, field(default=Unset, metadata=param_obj))
body_or_query_param_fields.append(param_field)
_add_body_or_query_param_field(
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
)
except Exception:
logger.error(
"Encountered an error while processing the param object in 'requestBody':\n"
Expand All @@ -283,6 +288,18 @@ def parse_schema_obj(obj: dict[str, Any]):
return content_type


def _add_body_or_query_param_field(
param_fields: list[DataclassModelField],
param_name: str,
param_type_annotation: Any,
param_obj: Mapping[str, Any] | dict[str, Any] | Sequence[dict[str, Any]] | None = None,
):
if param_name not in [x[0] for x in param_fields]:
param_fields.append(
DataclassModelField(param_name, param_type_annotation, default=field(default=Unset, metadata=param_obj))
)


def _is_file_param(
content_type: str,
param_def: ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from openapi_test_client.libraries.api.types import (
Alias,
DataclassModel,
DataclassModelField,
EndpointModel,
File,
ParamAnnotationType,
Expand Down Expand Up @@ -117,15 +118,22 @@ def create_model_from_param_def(
return _merge_models([create_model_from_param_def(model_name, p) for p in param_def])
else:
fields = [
(
DataclassModelField(
inner_param_name,
param_type_util.resolve_type_annotation(inner_param_name, ParamDef.from_param_obj(inner_param_obj)),
field(default=Unset, metadata=inner_param_obj),
default=field(default=Unset, metadata=inner_param_obj),
)
for inner_param_name, inner_param_obj in param_def.get("properties", {}).items()
]
alias_illegal_model_field_names(fields)
return cast(type[ParamModel], make_dataclass(model_name, fields, bases=(ParamModel,)))
return cast(
type[ParamModel],
make_dataclass(
model_name,
fields, # type: ignore
bases=(ParamModel,),
),
)


def generate_imports_code_from_model(
Expand Down Expand Up @@ -286,10 +294,10 @@ def visit(model_name: str):
return sorted(models, key=lambda x: sorted_models_names.index(x.__name__))


def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Field]]):
def alias_illegal_model_field_names(model_fields: list[DataclassModelField]):
"""Clean illegal model field name and annotate the field type with Alias class
:param param_fields: fields value to be passed to make_dataclass()
:param model_fields: fields value to be passed to make_dataclass()
"""

def make_alias(name: str, param_type: Any) -> str:
Expand Down Expand Up @@ -330,22 +338,17 @@ def make_alias(name: str, param_type: Any) -> str:
name += "_"
return name

if param_fields:
for i, param_field in enumerate(param_fields):
if len(param_field) == 2:
# path parameters
field_name, field_type = param_field
field_obj = object
else:
# body or query parameters
field_name, field_type, field_obj = param_field

if (alias_name := make_alias(field_name, field_type)) != field_name:
if isinstance(field_obj, Field) and field_obj.metadata:
logger.warning(f"Converted parameter name '{field_name}' to '{alias_name}'")
new_fields = [alias_name, param_type_util.generate_annotated_type(field_type, Alias(field_name))]
new_fields.append(field_obj)
param_fields[i] = tuple(new_fields)
if model_fields:
for i, model_field in enumerate(model_fields):
if (alias_name := make_alias(model_field.name, model_field.type)) != model_field.name:
if isinstance(model_field.default, Field) and model_field.default.metadata:
logger.warning(f"Converted parameter name '{model_field.name}' to '{alias_name}'")
new_fields = (
alias_name,
param_type_util.generate_annotated_type(model_field.type, Alias(model_field.name)),
model_field.default,
)
model_fields[i] = DataclassModelField(*new_fields)


def _merge_models(models: list[type[ParamModel]]) -> type[ParamModel]:
Expand Down
2 changes: 1 addition & 1 deletion src/openapi_test_client/libraries/api/api_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def has_reference(obj: Any) -> bool:
return "'$ref':" in str(obj)

def resolve_recursive(reference: Any, schemas_seen: list[str] | None = None):
if not schemas_seen:
if schemas_seen is None:
schemas_seen = []
if isinstance(reference, dict):
for k, v in copy.deepcopy(reference).items():
Expand Down
10 changes: 9 additions & 1 deletion src/openapi_test_client/libraries/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
make_dataclass,
)
from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TypeVar, cast

from common_libs.decorators import freeze_args
from common_libs.hash import HashableDict
Expand Down Expand Up @@ -219,6 +219,14 @@ def to_pydantic(cls) -> type[PydanticModel]:
)


class DataclassModelField(NamedTuple):
"""Dataclass model field"""

name: str
type: Any
default: Field | type[MISSING] = MISSING


class EndpointModel(DataclassModel):
content_type: str | None
endpoint_func: EndpointFunc
Expand Down

0 comments on commit 52c89ae

Please sign in to comment.