Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type-Checking the Zip operator on Dictionary should not crash #149

Merged
merged 5 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions func_adl/ast/function_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def visit_SelectMany_of_SelectMany(self, parent: ast.Call, selection: ast.Lambda
"SelectMany", [cast(ast.AST, captured_body), cast(ast.AST, func_g)]
)
new_select_lambda = lambda_build(captured_arg, new_select)
new_selectmany = function_call("SelectMany", [seq, cast(ast.AST, new_select_lambda)])
return new_selectmany
new_select_many = function_call("SelectMany", [seq, cast(ast.AST, new_select_lambda)])
return new_select_many

def call_SelectMany(self, node: ast.Call, args: List[ast.AST]):
r"""
Expand Down Expand Up @@ -442,7 +442,7 @@ def visit_Subscript_Tuple(self, v: ast.Tuple, s: Union[ast.Num, ast.Constant, as
# Get the value out - this is due to supporting python 3.7-3.9
n = _get_value_from_index(s)
if n is None:
return ast.Subscript(v, s, ast.Load())
return ast.Subscript(v, s, ast.Load()) # type: ignore
assert isinstance(n, int), "Programming error: index is not an integer in tuple subscript"
if n >= len(v.elts):
raise FuncADLIndexError(
Expand All @@ -460,7 +460,7 @@ def visit_Subscript_List(self, v: ast.List, s: Union[ast.Num, ast.Constant, ast.
"""
n = _get_value_from_index(s)
if n is None:
return ast.Subscript(v, s, ast.Load())
return ast.Subscript(v, s, ast.Load()) # type: ignore
if n >= len(v.elts):
raise FuncADLIndexError(
f"Attempt to access the {n}th element of a tuple"
Expand All @@ -484,7 +484,7 @@ def visit_Subscript_Dict_with_value(self, v: ast.Dict, s: Union[str, int]):
if _get_value_from_index(value) == s:
return v.values[index]

return ast.Subscript(v, s, ast.Load())
return ast.Subscript(v, s, ast.Load()) # type: ignore

def visit_Subscript_Of_First(self, first: ast.AST, s):
"""
Expand Down
24 changes: 13 additions & 11 deletions func_adl/type_based_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def func_adl_callable(
Callable[[ObjectStream[W], ast.Call], Tuple[ObjectStream[W], ast.AST]]
] = None
):
"""Dectorator that will declare a function that can be used inline in
"""Decorator that will declare a function that can be used inline in
a `func_adl` expression. The body of the function, what the backend
translates it to, must be given by another route (e.g. via `MetaData`
and the `processor` argument).
Expand Down Expand Up @@ -388,7 +388,7 @@ def _fill_in_default_arguments(func: Callable, call: ast.Call) -> Tuple[ast.Call
)
return_type = Any

return call, return_type
return call, return_type # type: ignore


def fixup_ast_from_modifications(transformed_ast: ast.AST, original_ast: ast.Call) -> ast.Call:
Expand Down Expand Up @@ -582,7 +582,7 @@ def type_follow_in_callbacks(
# If this is a known collection class, we can use call-backs to follow it.
if get_origin(call_site_info.obj_type) in _g_collection_classes:
rtn_value = self.process_method_call_on_stream_obj(
_g_collection_classes[get_origin(call_site_info.obj_type)],
_g_collection_classes[get_origin(call_site_info.obj_type)], # type: ignore
m_name,
r_node,
get_args(call_site_info.obj_type)[0],
Expand Down Expand Up @@ -631,7 +631,7 @@ def process_method_call(self, node: ast.Call, obj_type: type) -> ast.AST:

Args:
node (ast.Call): The ast node
obj_type (type): The object type this method call is occuring against
obj_type (type): The object type this method call is occurring against

Returns:
ast.AST: An updated ast that is the new method call (with default args, etc.)
Expand All @@ -643,7 +643,7 @@ def process_method_call(self, node: ast.Call, obj_type: type) -> ast.AST:
base_obj_list_all = [obj_type]
if is_iterable(obj_type):
item_type = unwrap_iterable(obj_type)
base_obj_list_all += [c[item_type] for c in _g_collection_classes]
base_obj_list_all += [c[item_type] for c in _g_collection_classes] # type: ignore

assert isinstance(r_node.func, ast.Attribute)
m_name = r_node.func.attr
Expand Down Expand Up @@ -758,7 +758,7 @@ def process_function_call(self, node: ast.Call, func_info: _FuncAdlFunction) ->
f"function {func_info.function.__name__} ({str(e)})"
) from e

def process_paramaterized_method_call(
def process_parameterized_method_call(
self,
node: ast.Call,
obj_type: Type,
Expand Down Expand Up @@ -815,7 +815,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
if isinstance(t_node.func.value, ast.Attribute):
found_type = self.lookup_type(t_node.func.value.value)
if found_type is not None:
t_node = self.process_paramaterized_method_call(
t_node = self.process_parameterized_method_call(
t_node,
found_type,
t_node.func.value.attr,
Expand Down Expand Up @@ -941,17 +941,17 @@ def visit_Constant(self, node: ast.Constant) -> Any:
return node

def visit_Num(self, node: ast.Num) -> Any: # pragma: no cover
"3.7 compatability"
"3.7 compatibility"
self._found_types[node] = type(node.n)
return node

def visit_Str(self, node: ast.Str) -> Any: # pragma: no cover
"3.7 compatability"
"3.7 compatibility"
self._found_types[node] = str
return node

def visit_NameConstant(self, node: ast.NameConstant) -> Any: # pragma: no cover
"3.7 compatability"
"3.7 compatibility"
if node.value is None:
raise ValueError("Do not know how to work with pythons None")
self._found_types[node] = bool
Expand All @@ -968,6 +968,8 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
e for e, k in enumerate(t_node.value.keys) if k.value == key # type: ignore
]
if len(key_index) == 0:
if t_node.attr.lower() == "zip":
return t_node
raise ValueError(f"Key {key} not found in dict expression!!")
value = t_node.value.values[key_index[0]]
self._found_types[node] = self.lookup_type(value)
Expand Down Expand Up @@ -999,7 +1001,7 @@ def remap_from_lambda(
orig_type = o_stream.item_type
var_name = l_func.args.args[0].arg
stream, new_body, return_type = remap_by_types(o_stream, var_name, orig_type, l_func.body)
return stream, ast.Lambda(l_func.args, new_body), return_type
return stream, ast.Lambda(l_func.args, new_body), return_type # type: ignore


def reset_global_functions():
Expand Down
17 changes: 17 additions & 0 deletions tests/test_type_based_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,23 @@ def test_dictionary_bad_key():
assert "jetsss" in str(e)


def test_dictionary_Zip_key():
"Check that type follow from a dictionary through a Zip works"

s = ast_lambda(
"""({
'jet_pt': e.Jets().Select(lambda j: j.pt()),
'jet_eta': e.Jets().Select(lambda j: j.eta())}
.Zip()
.Select(lambda j: j.pt()))"""
)
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)

assert expr_type == Any


def test_dictionary_through_Select():
"""Make sure the Select statement carries the typing all the way through"""

Expand Down
Loading