Skip to content

Commit

Permalink
Added enum support to Measurement Client (#880)
Browse files Browse the repository at this point in the history
* Feat: Add enum support to Measurement Client

* Fix: Invalid chars update

* Revert: Wrong commit files

* Fix: Mypy errors

* Fix: More Mypy errors

* Tests: Add enum param in test measurement

* Fix: Merge errors

* Fix: Lint errors

* Revert: Version change

* Revert: Unused imports

* Fix: Add ienum formatting in black

* Refractor: Use Dict get()

* Fix: Minor refractoring

* Client: Cache Enum type

* Client: Update type hints

* Fix: Refractor Pascal case conversion

* Tests: Update tests for enum parameters

* Test: Assest file changes

* Fix: Mypy error on enum dynamic creation

* Fix: Remove main function

* Fix: Lint errors

* Client: Supress MyPy error

* Client: Remove GetMetadata in template file

* Client: Remove unused imports

* Tests: Synchronize test measurement client

* Fix: Spacing

* Client: Intialize enum_values_by_type for each measurement
  • Loading branch information
MounikaBattu17 authored Sep 17, 2024
1 parent 0462410 commit 09b40b0
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Utilizes command line args to create a Measurement Plug-In Client using template files."""

import pathlib
from typing import Any, List, Optional
import re
from enum import Enum
from typing import Any, Dict, List, Optional, Type

import black
import click
Expand Down Expand Up @@ -32,12 +34,18 @@ def _render_template(template_name: str, **template_args: Any) -> bytes:
return template.render(**template_args)


def _replace_enum_class_type(output: str) -> str:
pattern = "<enum '([^']+)'>"
return re.sub(pattern, r"\1", output)


def _create_file(
template_name: str, file_name: str, directory_out: pathlib.Path, **template_args: Any
) -> None:
output_file = directory_out / file_name

output = _render_template(template_name, **template_args).decode("utf-8")
output = _replace_enum_class_type(output)
formatted_output = black.format_str(
src_contents=output,
mode=black.Mode(line_length=100),
Expand Down Expand Up @@ -109,6 +117,7 @@ def create_client(

is_multiple_client_generation = len(measurement_service_class) > 1
for service_class in measurement_service_class:
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {}
if is_multiple_client_generation or module_name is None or class_name is None:
base_service_class = service_class.split(".")[-1]
base_service_class = remove_suffix(base_service_class)
Expand Down Expand Up @@ -146,16 +155,21 @@ def create_client(
metadata = measurement_service_stub.GetMetadata(
v2_measurement_service_pb2.GetMetadataRequest()
)
configuration_metadata = get_configuration_metadata_by_index(metadata, service_class)
output_metadata = get_output_metadata_by_index(metadata)
configuration_metadata = get_configuration_metadata_by_index(
metadata, service_class, enum_values_by_type
)
output_metadata = get_output_metadata_by_index(metadata, enum_values_by_type)

configuration_parameters_with_type_and_default_values, measure_api_parameters = (
get_configuration_parameters_with_type_and_default_values(
configuration_metadata, built_in_import_modules
configuration_metadata, built_in_import_modules, enum_values_by_type
)
)
output_parameters_with_type = get_output_parameters_with_type(
output_metadata, built_in_import_modules, custom_import_modules
output_metadata,
built_in_import_modules,
custom_import_modules,
enum_values_by_type,
)

_create_file(
Expand All @@ -172,6 +186,7 @@ def create_client(
output_parameters_with_type=output_parameters_with_type,
built_in_import_modules=to_ordered_set(built_in_import_modules),
custom_import_modules=to_ordered_set(custom_import_modules),
enum_by_class_name=enum_values_by_type,
)

print(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Support functions for the Measurement Plug-In Client generator."""

import json
import keyword
import os
import re
import sys
from typing import AbstractSet, Dict, Iterable, List, Optional, Tuple, TypeVar
from enum import Enum
from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar

import click
import grpc
from google.protobuf import descriptor_pool
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
from google.protobuf.type_pb2 import Field
from ni_measurement_plugin_sdk_service._internal.grpc_servicer import frame_metadata_dict
from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v2 import (
Expand Down Expand Up @@ -94,7 +97,9 @@ def get_all_registered_measurement_service_classes(discovery_client: DiscoveryCl


def get_configuration_metadata_by_index(
metadata: v2_measurement_service_pb2.GetMetadataResponse, service_class: str
metadata: v2_measurement_service_pb2.GetMetadataResponse,
service_class: str,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> Dict[int, ParameterMetadata]:
"""Returns the configuration metadata of the measurement."""
configuration_parameter_list = []
Expand All @@ -107,6 +112,11 @@ def get_configuration_metadata_by_index(
default_value=None,
annotations=dict(configuration.annotations.items()),
message_type=configuration.message_type,
enum_type=(
_get_enum_type(configuration, enum_values_by_type)
if _is_enum_param(configuration.type)
else None
),
)
)

Expand All @@ -120,6 +130,11 @@ def get_configuration_metadata_by_index(
default_value=None,
annotations=dict(output.annotations.items()),
message_type=output.message_type,
enum_type=(
_get_enum_type(output, enum_values_by_type)
if _is_enum_param(output.type)
else None
),
)
)

Expand All @@ -137,13 +152,21 @@ def get_configuration_metadata_by_index(
)

for k, v in deserialized_parameters.items():
configuration_metadata[k] = configuration_metadata[k]._replace(default_value=v)
if issubclass(type(v), Enum):
default_value = v.value
elif issubclass(type(v), list) and any(issubclass(type(e), Enum) for e in v):
default_value = [e.value for e in v]
else:
default_value = v

configuration_metadata[k] = configuration_metadata[k]._replace(default_value=default_value)

return configuration_metadata


def get_output_metadata_by_index(
metadata: v2_measurement_service_pb2.GetMetadataResponse,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> Dict[int, ParameterMetadata]:
"""Returns the output metadata of the measurement."""
output_parameter_list = []
Expand All @@ -156,6 +179,11 @@ def get_output_metadata_by_index(
default_value=None,
annotations=dict(output.annotations.items()),
message_type=output.message_type,
enum_type=(
_get_enum_type(output, enum_values_by_type)
if _is_enum_param(output.type)
else None
),
)
)
output_metadata = frame_metadata_dict(output_parameter_list)
Expand All @@ -165,6 +193,7 @@ def get_output_metadata_by_index(
def get_configuration_parameters_with_type_and_default_values(
configuration_metadata: Dict[int, ParameterMetadata],
built_in_import_modules: List[str],
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> Tuple[str, str]:
"""Returns configuration parameters of the measurement with type and default values."""
configuration_parameters = []
Expand All @@ -189,6 +218,23 @@ def get_configuration_parameters_with_type_and_default_values(
else:
default_value = f"Path({default_value})"

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type = _get_enum_type(metadata, enum_values_by_type)
parameter_type = enum_type.__name__
if metadata.repeated:
values = []
for val in default_value:
enum_value = next((e.name for e in enum_type if e.value == val), None)
values.append(f"{parameter_type}.{enum_value}")
concatenated_default_value = ", ".join(values)
concatenated_default_value = f"[{concatenated_default_value}]"

parameter_type = f"List[{parameter_type}]"
default_value = concatenated_default_value
else:
enum_value = next((e.name for e in enum_type if e.value == default_value), None)
default_value = f"{parameter_type}.{enum_value}"

configuration_parameters.append(f"{parameter_name}: {parameter_type} = {default_value}")

# Use line separator and spaces to align the parameters appropriately in the generated file.
Expand All @@ -204,6 +250,7 @@ def get_output_parameters_with_type(
output_metadata: Dict[int, ParameterMetadata],
built_in_import_modules: List[str],
custom_import_modules: List[str],
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> str:
"""Returns the output parameters of the measurement with type."""
output_parameters_with_type = []
Expand All @@ -225,6 +272,10 @@ def get_output_parameters_with_type(
if metadata.repeated:
parameter_type = f"List[{parameter_type}]"

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type_name = _get_enum_type(metadata, enum_values_by_type).__name__
parameter_type = f"List[{enum_type_name}]" if metadata.repeated else enum_type_name

output_parameters_with_type.append(f"{parameter_name}: {parameter_type}")

return f"{os.linesep} ".join(output_parameters_with_type)
Expand Down Expand Up @@ -284,3 +335,35 @@ def _get_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str:
if is_array:
return f"List[{python_type.__name__}]"
return python_type.__name__


def _is_enum_param(parameter_type: int) -> bool:
return parameter_type == FieldDescriptorProto.TYPE_ENUM


def _get_enum_type(
parameter: Any, enum_values_by_type: Dict[Type[Enum], Dict[str, int]]
) -> Type[Enum]:
loaded_enum_values = json.loads(parameter.annotations["ni/enum.values"])
enum_values = {key: value for key, value in loaded_enum_values.items()}

for existing_enum_type, existing_enum_values in enum_values_by_type.items():
if existing_enum_values == enum_values:
return existing_enum_type

new_enum_type_name = _get_enum_class_name(parameter.name)
# MyPy error: Enum() expects a string literal as the first argument.
# Ignoring this error because MyPy cannot validate dynamic Enum creation statically.
new_enum_type = Enum(new_enum_type_name, enum_values) # type: ignore[misc]
enum_values_by_type[new_enum_type] = enum_values
return new_enum_type


def _get_enum_class_name(name: str) -> str:
name = re.sub(r"[^\w\s]", "", name).replace("_", " ")
split_string = name.split()
if len(split_string) > 1:
name = "".join(s.capitalize() for s in split_string)
else:
name = name[0].upper() + name[1:]
return name + "Enum"
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
<%page args="class_name, display_name, configuration_metadata, output_metadata, service_class, configuration_parameters_with_type_and_default_values, measure_api_parameters, output_parameters_with_type, built_in_import_modules, custom_import_modules"/>\
<%page args="class_name, display_name, configuration_metadata, output_metadata, service_class, configuration_parameters_with_type_and_default_values, measure_api_parameters, output_parameters_with_type, built_in_import_modules, custom_import_modules, enum_by_class_name"/>\
\
"""Generated client API for the ${display_name | repr} measurement plug-in."""

import logging
import pathlib
import threading
% if len(enum_by_class_name):
from enum import Enum
% endif
% for module in built_in_import_modules:
${module}
% endfor
Expand Down Expand Up @@ -35,6 +38,16 @@ _logger = logging.getLogger(__name__)

_V2_MEASUREMENT_SERVICE_INTERFACE = "ni.measurementlink.measurement.v2.MeasurementService"

% for enum_name, enum_value in enum_by_class_name.items():

class ${enum_name.__name__}(Enum):
"""${enum_name.__name__} used for enum-typed measurement configs and outputs."""

% for key, val in enum_value.items():
${key} = ${val}
% endfor
% endfor

<% output_type = "None" %>\
% if output_metadata:

Expand Down Expand Up @@ -150,36 +163,9 @@ class ${class_name}:
return self._pin_map_client

def _create_file_descriptor(self) -> None:
metadata = self._get_stub().GetMetadata(v2_measurement_service_pb2.GetMetadataRequest())
configuration_metadata = []
for configuration in metadata.measurement_signature.configuration_parameters:
configuration_metadata.append(
ParameterMetadata.initialize(
display_name=configuration.name,
type=configuration.type,
repeated=configuration.repeated,
default_value=None,
annotations=dict(configuration.annotations.items()),
message_type=configuration.message_type,
)
)

output_metadata = []
for output in metadata.measurement_signature.outputs:
output_metadata.append(
ParameterMetadata.initialize(
display_name=output.name,
type=output.type,
repeated=output.repeated,
default_value=None,
annotations=dict(output.annotations.items()),
message_type=output.message_type,
)
)

create_file_descriptor(
input_metadata=configuration_metadata,
output_metadata=output_metadata,
input_metadata=list(self._configuration_metadata.values()),
output_metadata=list(self._output_metadata.values()),
service_name=self._service_class,
pool=descriptor_pool.Default(),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.util
import pathlib
from enum import Enum
from types import ModuleType
from typing import Generator

Expand All @@ -11,6 +12,15 @@
from tests.utilities.measurements import non_streaming_data_measurement


class EnumInEnum(Enum):
"""EnumInEnum used for enum-typed measurement configs and outputs."""

NONE = 0
RED = 1
GREEN = 2
BLUE = 3


def test___measurement_plugin_client___measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
Expand Down Expand Up @@ -42,12 +52,16 @@ def test___measurement_plugin_client___measure___returns_output(
io_array_out=["resource1", "resource2"],
integer_out=10,
xy_data_out=None,
enum_out=EnumInEnum.BLUE,
enum_array_out=[EnumInEnum.RED, EnumInEnum.GREEN],
)
measurement_plugin_client = test_measurement_client_type()

response = measurement_plugin_client.measure()

assert response == expected_output
# Enum values are not comparable due to differing imports.
# So comparing values by converting them to string.
assert str(response) == str(expected_output)


def test___measurement_plugin_client___stream_measure___returns_output(
Expand Down Expand Up @@ -81,14 +95,18 @@ def test___measurement_plugin_client___stream_measure___returns_output(
io_array_out=["resource1", "resource2"],
integer_out=10,
xy_data_out=None,
enum_out=EnumInEnum.BLUE,
enum_array_out=[EnumInEnum.RED, EnumInEnum.GREEN],
)
measurement_plugin_client = test_measurement_client_type()

response_iterator = measurement_plugin_client.stream_measure()

responses = [response for response in response_iterator]
assert len(responses) == 1
assert responses[0] == expected_output
# Enum values are not comparable due to differing imports.
# So comparing values by converting them to string.
assert str(responses[0]) == str(expected_output)


@pytest.fixture(scope="module")
Expand Down
Loading

0 comments on commit 09b40b0

Please sign in to comment.