Skip to content

Commit

Permalink
Switch to use ruff for formatting
Browse files Browse the repository at this point in the history
Remove autoflake, isort, and black from dependencies
  • Loading branch information
yugokato committed Feb 7, 2025
1 parent 1320280 commit 5c7e2b7
Show file tree
Hide file tree
Showing 17 changed files with 109 additions and 128 deletions.
28 changes: 3 additions & 25 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,8 @@ repos:
args:
- --fix
types: [ python ]
# TODO: Replace autoflake, isort, and black with ruff once ruff supports public APIs (https://github.com/astral-sh/ruff/issues/659)
- id: autoflake
name: autoflake
entry: autoflake
- id: ruff-format
name: ruff-format
entry: ruff format
language: system
args:
- --recursive
- --check-diff
- --remove-all-unused-imports
- --ignore-init-module-imports
- --ignore-pass-statements
- --quiet
types: [ python ]
- id: isort
name: isort
entry: isort
language: system
args:
- --check
types: [ python ]
- id: black
name: black
entry: black
language: system
args:
- --check
types: [ python ]
13 changes: 1 addition & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ classifiers = [
"Topic :: Software Development :: Testing",
]
dependencies = [
"autoflake==2.3.1",
"black==23.12.1",
"common-libs[client]@git+https://github.com/yugokato/common-libs",
"isort==5.13.2",
"inflect==7.0.0",
"phonenumbers==8.13.45",
"pydantic-extra-types==2.9.0",
Expand Down Expand Up @@ -62,15 +59,6 @@ local_scheme = "no-local-version"
python_files = ["test_*.py"]
testpaths = "tests"

[tool.isort]
line_length = 120
multi_line_output = 3
include_trailing_comma = true
profile = "black"

[tool.black]
line_length = 120

[tool.ruff]
line-length = 120
indent-width = 4
Expand All @@ -79,6 +67,7 @@ indent-width = 4
select = [
"E", # pycodestyle
"F", # Pyflakes
"I", # isort
"UP", # pyupgrade
]
ignore = ["E731", "E741", "F403"]
Expand Down
3 changes: 1 addition & 2 deletions src/demo_app/handlers/error_handlers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
from dataclasses import dataclass

from quart import Blueprint, Response
from quart import Blueprint, Response, jsonify, make_response, request
from quart import current_app as app
from quart import jsonify, make_response, request
from quart_schema import RequestSchemaValidationError
from werkzeug.exceptions import NotFound

Expand Down
12 changes: 6 additions & 6 deletions src/openapi_test_client/libraries/api/api_classes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def init_api_classes(base_api_class: type[APIClassType]) -> list[type[APIClassTy

previous_frame = inspect.currentframe().f_back
caller_file_path = inspect.getframeinfo(previous_frame).filename
assert caller_file_path.endswith(
"__init__.py"
), f"API classes must be initialized in __init__.py. Unexpectedly called from {caller_file_path}"
assert caller_file_path.endswith("__init__.py"), (
f"API classes must be initialized in __init__.py. Unexpectedly called from {caller_file_path}"
)

# Set each API class's available Endpoint objects to its endpoints attribute
api_classes = get_api_classes(Path(caller_file_path).parent, base_api_class)
Expand Down Expand Up @@ -80,7 +80,7 @@ def get_api_classes(api_class_dir: Path, base_api_class: type[APIClassType]) ->
and issubclass(getattr(mod, x), base_api_class)
and x != base_api_class.__name__
]
assert (
api_classes
), f"Unable to find any API class that is a subclass of {base_api_class.__name__} in {api_class_dir}"
assert api_classes, (
f"Unable to find any API class that is a subclass of {base_api_class.__name__} in {api_class_dir}"
)
return api_classes
33 changes: 13 additions & 20 deletions src/openapi_test_client/libraries/api/api_client_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,11 @@ def update_endpoint_functions(
# endpoint path and endpoint options
rf"(\n{tab}{{2}})?\"(?P<path>.+?)\"(?P<ep_options>,.+?)?(\n{tab})?\)\n"
# function def
rf"(?P<func_def>{tab}def (?P<func_name>.+?)\((?P<signature>.+?){tab}?\) -> {RestResponse.__name__}:\n)"
rf"(?P<func_def>{tab}def (?P<func_name>.+?)\((?P<signature>.+?){tab}?\) -> {RestResponse.__name__}:\n?)"
# docstring
rf"({tab}{{2}}(?P<docstring>\"{{3}}.*?\"{{3}})\n)?"
# function body
rf"(?P<func_body>\n*{tab}{{2}}(?:[^@]+|\.{{3}})\n)?$",
rf"(?P<func_body>(?:\n*{tab}{{2}}(?:[^@]+|\.{{3}})| \.{{3}})\n)?$",
flags=re.MULTILINE | re.DOTALL,
)

Expand Down Expand Up @@ -324,12 +324,12 @@ def update_existing_endpoints(target_api_class: type[APIClassType] = api_class):
endpoint_str = f"{method.upper()} {path}"
defined_endpoints.append((method, path))

# For troubleshooting
# # For troubleshooting
# print(
# f"{method.upper()} {path}:\n"
# f" - matched: {repr(matched.group(0))}\n"
# f" - decorators: {repr(decorators)}\n"
# f" - func_def: {repr(matched.group("func_def"))}\n"
# f" - decorators: {repr(matched.group('decorators'))}\n"
# f" - func_def: {repr(func_def)}\n"
# f" - func_name: {repr(func_name)}\n"
# f" - signature: {repr(signature)}\n"
# f" - docstring: {repr(docstring)}\n"
Expand Down Expand Up @@ -373,8 +373,7 @@ def update_existing_endpoints(target_api_class: type[APIClassType] = api_class):

# Collect all param models for this endpoint
param_models.extend(param_model_util.get_param_models(endpoint_model))
# Fill missing imports (typing and custom param model classes). Duplicates will be removed by black at
# the end
# Fill missing imports (typing and custom param model classes). Duplicates will be removed at the end
if missing_imports_code := param_model_util.generate_imports_code_from_model(api_class, endpoint_model):
new_code = missing_imports_code + new_code

Expand All @@ -387,8 +386,10 @@ def update_existing_endpoints(target_api_class: type[APIClassType] = api_class):
updated_api_func_code = updated_api_func_code.replace(docstring, expected_docstring)
else:
updated_api_func_code = updated_api_func_code.replace(
func_def, func_def + f"{TAB * 2}{expected_docstring}\n"
func_def, func_def + f"\n{TAB * 2}{expected_docstring}\n"
)
if func_body == " ...\n":
updated_api_func_code = updated_api_func_code.replace(func_body, f"{TAB * 2}...\n")

# Update API function signatures
new_func_signature = endpoint_model_util.generate_func_signature_in_str(endpoint_model).replace(
Expand Down Expand Up @@ -541,8 +542,7 @@ def update_missing_endpoints():

if param_models:
modified_model_code = (
f"from dataclasses import dataclass\n\n"
f"from {ParamModel.__module__} import {ParamModel.__name__}\n\n"
f"from dataclasses import dataclass\n\nfrom {ParamModel.__module__} import {ParamModel.__name__}\n\n"
)
for model in param_model_util.sort_by_dependency(param_model_util.dedup_models_by_name(param_models)):
imports_code, model_code = param_model_util.generate_model_code_from_model(api_class, model)
Expand Down Expand Up @@ -605,8 +605,7 @@ def generate_api_client(temp_api_client: OpenAPIClient, show_generated_code: boo
api_client_class_name = f"{api_client_class_name_part}{API_CLIENT_CLASS_NAME_SUFFIX}"

imports_code = (
f"from functools import cached_property\n\n"
f"from {OpenAPIClient.__module__} import {OpenAPIClient.__name__}\n"
f"from functools import cached_property\n\nfrom {OpenAPIClient.__module__} import {OpenAPIClient.__name__}\n"
)
api_client_code = (
f"class {api_client_class_name}({OpenAPIClient.__name__}):\n"
Expand All @@ -624,9 +623,7 @@ def generate_api_client(temp_api_client: OpenAPIClient, show_generated_code: boo
imports_code += f"from .{API_CLASS_DIR_NAME}.{Path(mod.__file__).stem} import {api_class.__name__}\n"
property_name = camel_to_snake(api_class.__name__.removesuffix("API")).upper()
api_client_code += (
f"{TAB}@cached_property\n"
f"{TAB}def {property_name}(self):\n"
f"{TAB}{TAB}return {api_class.__name__}(self)\n\n"
f"{TAB}@cached_property\n{TAB}def {property_name}(self):\n{TAB}{TAB}return {api_class.__name__}(self)\n\n"
)

code = format_code(imports_code + api_client_code)
Expand Down Expand Up @@ -664,11 +661,7 @@ def setup_external_directory(client_name: str, base_url: str, env: str = DEFAULT
api_client_lib_dir = get_package_dir()
api_client_lib_dir.mkdir(parents=True, exist_ok=True)
# Add __init__.py
code = (
f"import os\n"
f"from pathlib import Path\n\n"
f'os.environ["{ENV_VAR_PACKAGE_DIR}"] = str(Path(__file__).parent)'
)
code = f'import os\nfrom pathlib import Path\n\nos.environ["{ENV_VAR_PACKAGE_DIR}"] = str(Path(__file__).parent)'
_write_init_file(api_client_lib_dir, format_code(DO_NOT_DELETE_COMMENT + code))

# Add a hidden file to the package directory so that we can locate this directory later
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def __init__(self, endpoint_handler: EndpointHandler, instance: APIClassType | N
self.path,
self._original_func.__name__,
self.model,
url=f"{self.rest_client.url_base}{self.path }" if instance else None,
url=f"{self.rest_client.url_base}{self.path}" if instance else None,
content_type=endpoint_handler.content_type,
is_public=endpoint_handler.is_public,
is_documented=owner.is_documented and endpoint_handler.is_documented,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
"""
path_param_fields = []
body_or_query_param_fields = []
model_name = f'{type(endpoint_func).__name__.replace("EndpointFunc", EndpointModel.__name__)}'
model_name = f"{type(endpoint_func).__name__.replace('EndpointFunc', EndpointModel.__name__)}"
content_type = None
if api_spec:
# Generate model fields from the OpenAPI spec. See https://swagger.io/specification/ for the specification
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@
from functools import reduce
from operator import or_
from types import NoneType, UnionType
from typing import _AnnotatedAlias # noqa
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin
from typing import (
Annotated,
Any,
Literal,
Optional,
Union,
_AnnotatedAlias, # noqa
get_args,
get_origin,
)

from common_libs.logging import get_logger

Expand Down Expand Up @@ -56,7 +64,7 @@ def get_type_annotation_as_str(tp: Any) -> str:
return f"{type(tp).__name__}({repr(tp.value)})"
elif isinstance(tp, Constraint):
const = ", ".join(
f'{k}={("r" + repr(v).replace(BACKSLASH * 2, BACKSLASH) if k == "pattern" else repr(v))}'
f"{k}={('r' + repr(v).replace(BACKSLASH * 2, BACKSLASH) if k == 'pattern' else repr(v))}"
for k, v in asdict(tp).items()
if v is not None
)
Expand Down
4 changes: 2 additions & 2 deletions src/openapi_test_client/libraries/api/api_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_endpoint_usage(self, endpoint: Endpoint) -> str | None:
summary = ep_doc.get("summary")
parameters = ep_doc.get("parameters")
request_body = ep_doc.get("requestBody")
usage = f"- Method: {method.upper()}\n" f"- Path: {path}\n- Summary: {summary}\n"
usage = f"- Method: {method.upper()}\n- Path: {path}\n- Summary: {summary}\n"
if parameters:
usage += f"- Parameters: {json.dumps(parameters, indent=4)}\n"
if request_body:
Expand All @@ -97,7 +97,7 @@ def parse(api_spec: dict[str, Any]) -> dict[str, Any]:
if tags := parsed_spec.get("tags"):
if undefined_endpoint_tags := set(endpoint_tags).difference(set([t["name"] for t in tags])):
logger.warning(
f'One ore more endpoint tags are not defined at the top-level "tags": ' f"{undefined_endpoint_tags}"
f'One or more endpoint tags are not defined at the top-level "tags": {undefined_endpoint_tags}'
)
else:
# We need the top-level "tags" but it is either not defined or empty.
Expand Down
38 changes: 27 additions & 11 deletions src/openapi_test_client/libraries/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@

import json
from collections.abc import Callable, Mapping, Sequence
from dataclasses import _DataclassParams # noqa
from dataclasses import MISSING, Field, asdict, astuple, dataclass, field, is_dataclass, make_dataclass
from dataclasses import (
MISSING,
Field,
_DataclassParams, # noqa
asdict,
astuple,
dataclass,
field,
is_dataclass,
make_dataclass,
)
from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast

Expand Down Expand Up @@ -100,20 +109,17 @@ class ParamGroup(tuple):
def is_required(self) -> bool:
return any(p.is_required for p in self)

class OneOf(ParamGroup):
...
class OneOf(ParamGroup): ...

class AnyOf(ParamGroup):
...
class AnyOf(ParamGroup): ...

class AllOf(ParamGroup):
...
class AllOf(ParamGroup): ...

@staticmethod
@freeze_args
@lru_cache
def from_param_obj(
param_obj: Mapping[str, Any] | dict[str, Any] | Sequence[dict[str, Any]]
param_obj: Mapping[str, Any] | dict[str, Any] | Sequence[dict[str, Any]],
) -> ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType:
"""Convert the parameter object to a ParamDef"""

Expand All @@ -131,7 +137,15 @@ def convert(obj: Any):
convert(p)
for p in obj["allOf"]
if any(
key in p.keys() for key in ["oneOf", "anyOf", "allOf", "schema", "type", "properties"]
key in p.keys()
for key in [
"oneOf",
"anyOf",
"allOf",
"schema",
"type",
"properties",
]
)
]
)
Expand Down Expand Up @@ -378,7 +392,9 @@ def setdefault(self, key: str, default: Any = None) -> Any:

@classmethod
def recreate(
cls, current_class: type[ParamModel], new_fields: list[tuple[str, Any, Field | None]]
cls,
current_class: type[ParamModel],
new_fields: list[tuple[str, Any, Field | None]],
) -> type[ParamModel]:
"""Recreate the model with the new fields
Expand Down
Loading

0 comments on commit 5c7e2b7

Please sign in to comment.