From 17a5f5271425f8fa30718507c871591e46e26fde Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dandelion=20Man=C3=A9?=
Date: Mon, 10 Mar 2025 17:33:45 -0400
Subject: [PATCH] fix: make llm.context work for async functions

---
 mirascope/llm/_call.py     |  87 ++++++++++++---------
 mirascope/llm/_context.py  |  13 ++--
 tests/llm/test_call.py     | 154 +++++++++++++++++++++++++++++++++++++
 tests/llm/test_context.py  |  47 +++++++++++
 tests/llm/test_override.py |  56 ++++++++++++++
 5 files changed, 316 insertions(+), 41 deletions(-)

diff --git a/mirascope/llm/_call.py b/mirascope/llm/_call.py
index 580f567ab..18a2c9e4d 100644
--- a/mirascope/llm/_call.py
+++ b/mirascope/llm/_call.py
@@ -20,7 +20,11 @@
 from ..core.base._utils import fn_is_async
 from ..core.base.stream_config import StreamConfig
 from ..core.base.types import LocalProvider, Provider
-from ._context import CallArgs, apply_context_overrides_to_call_args
+from ._context import (
+    CallArgs,
+    apply_context_overrides_to_call_args,
+    get_current_context,
+)
 from ._protocols import (
     AsyncLLMFunctionDecorator,
     CallDecorator,
@@ -235,55 +239,66 @@ def wrapper(
         | Awaitable[(_ResponseModelT | CallResponse)],
     ]:
         if fn_is_async(fn):
-
+            # Create a wrapper function that captures the current context when called
             @wraps(fn)
-            async def inner_async(
+            def wrapper_with_context(
                 *args: _P.args, **kwargs: _P.kwargs
-            ) -> (
+            ) -> Awaitable[
                 CallResponse
                 | Stream
                 | _ResponseModelT
                 | _ParsedOutputT
                 | (_ResponseModelT | CallResponse)
-            ):
-                # Apply any context overrides to the original call args
-                effective_call_args = apply_context_overrides_to_call_args(
-                    original_call_args
-                )
+            ]:
+                # Capture the context at call time
+                current_context = get_current_context()
+
+                # Define an async function that uses the captured context
+                async def context_bound_inner_async() -> (
+                    CallResponse
+                    | Stream
+                    | _ResponseModelT
+                    | _ParsedOutputT
+                    | (_ResponseModelT | CallResponse)
+                ):
+                    # Apply any context overrides to the original call args
+                    effective_call_args = apply_context_overrides_to_call_args(
+                        original_call_args, context_override=current_context
+                    )
 
-                # Get the appropriate provider call function with the possibly overridden provider
-                effective_provider = effective_call_args["provider"]
-                effective_client = effective_call_args["client"]
+                    # Get the appropriate provider call function with the possibly overridden provider
+                    effective_provider = effective_call_args["provider"]
+                    effective_client = effective_call_args["client"]
 
-                if effective_provider in get_args(LocalProvider):
-                    provider_call, effective_client = _get_local_provider_call(
-                        cast(LocalProvider, effective_provider),
-                        effective_client,
-                        True,
-                    )
-                    effective_call_args["client"] = effective_client
-                else:
-                    provider_call = _get_provider_call(
-                        cast(Provider, effective_provider)
-                    )
+                    if effective_provider in get_args(LocalProvider):
+                        provider_call, effective_client = _get_local_provider_call(
+                            cast(LocalProvider, effective_provider),
+                            effective_client,
+                            True,
+                        )
+                        effective_call_args["client"] = effective_client
+                    else:
+                        provider_call = _get_provider_call(
+                            cast(Provider, effective_provider)
+                        )
 
-                # Use the provider-specific call function with overridden args
-                call_kwargs = dict(effective_call_args)
-                del call_kwargs[
-                    "provider"
-                ]  # Remove provider as it's not a parameter to provider_call
+                    # Use the provider-specific call function with overridden args
+                    call_kwargs = dict(effective_call_args)
+                    del call_kwargs["provider"]  # Not a parameter to provider_call
 
-                # Get decorated function using provider_call
-                decorated = provider_call(**call_kwargs)(fn)
+                    # Get decorated function using provider_call
+                    decorated = provider_call(**call_kwargs)(fn)
 
-                # Call the decorated function and wrap the result
-                result = await decorated(*args, **kwargs)
-                return _wrap_result(result)
+                    # Call the decorated function and wrap the result
+                    result = await decorated(*args, **kwargs)
+                    return _wrap_result(result)
+
+                return context_bound_inner_async()
 
-            inner_async._original_call_args = original_call_args  # pyright: ignore [reportAttributeAccessIssue]
-            inner_async._original_fn = fn  # pyright: ignore [reportAttributeAccessIssue]
+            wrapper_with_context._original_call_args = original_call_args  # pyright: ignore [reportAttributeAccessIssue]
+            wrapper_with_context._original_fn = fn  # pyright: ignore [reportAttributeAccessIssue]
 
-            return inner_async
+            return wrapper_with_context  # pyright: ignore [reportReturnType]
         else:
 
             @wraps(fn)
diff --git a/mirascope/llm/_context.py b/mirascope/llm/_context.py
index 4f6300bf8..8abddea5c 100644
--- a/mirascope/llm/_context.py
+++ b/mirascope/llm/_context.py
@@ -114,10 +114,10 @@ def _context(
     client: Any | None = None,  # noqa: ANN401
     call_params: CommonCallParams | Any | None = None,  # noqa: ANN401
 ) -> LLMContext:
-    """Context manager for synchronous LLM API calls.
+    """Context manager for LLM API calls.
 
     This is an internal method that allows both setting and structural overrides
-    for synchronous functions.
+    for LLM functions.
 
     Unfortunately we have not yet identified a way to properly type hint this
     because providing no structural overrides means the return type is that of the original
@@ -138,7 +138,7 @@
         client: The client to use for the LLM API call.
         call_params: The call parameters for the LLM API call.
 
-    Yields:
+    Returns:
         The context object that can be used to apply the context to a function.
     """
     old_context: LLMContext | None = getattr(_current_context_local, "context", None)
@@ -171,16 +171,19 @@
 )
 
 
-def apply_context_overrides_to_call_args(call_args: CallArgs) -> CallArgs:
+def apply_context_overrides_to_call_args(
+    call_args: CallArgs, context_override: LLMContext | None = None
+) -> CallArgs:
     """Apply any active context overrides to the call arguments.
 
     Args:
         call_args: The original call arguments.
+        context_override: Optional explicit context to use instead of the current thread context.
 
     Returns:
         The call arguments with any context overrides applied.
     """
-    context = get_current_context()
+    context = context_override or get_current_context()
     if not context:
         return call_args
 
diff --git a/tests/llm/test_call.py b/tests/llm/test_call.py
index 3c87af416..e7a6a9035 100644
--- a/tests/llm/test_call.py
+++ b/tests/llm/test_call.py
@@ -22,6 +22,7 @@
     _wrap_result,
     call,
 )
+from mirascope.llm._context import context
 from mirascope.llm.call_response import CallResponse
 from mirascope.llm.stream import Stream
 
@@ -384,3 +385,156 @@ async def dummy_local_async_function(): ...
     res = await dummy_local_async_function()
     assert isinstance(res, CallResponse)
     assert res.finish_reasons == ["stop"]
+
+
+@pytest.mark.asyncio
+async def test_context_in_async_function():
+    """Test that context is properly applied in async functions."""
+    # Create a mock provider call that captures the effective call args
+    captured_args = {}
+
+    def dummy_async_provider_call(
+        model,
+        stream,
+        tools,
+        response_model,
+        output_parser,
+        json_mode,
+        call_params,
+        client,
+    ):
+        def wrapper(fn):
+            async def inner(*args, **kwargs):
+                # Store the args that were passed to the provider call
+                nonlocal captured_args
+                captured_args = {
+                    "model": model,
+                    "stream": stream,
+                    "tools": tools,
+                    "response_model": response_model,
+                    "output_parser": output_parser,
+                    "json_mode": json_mode,
+                    "call_params": call_params,
+                    "client": client,
+                }
+
+                return ConcreteResponse(
+                    metadata=Metadata(),
+                    response={},
+                    tool_types=None,
+                    prompt_template=None,
+                    fn_args={},
+                    dynamic_config={},
+                    messages=[],
+                    call_params=DummyCallParams(),
+                    call_kwargs=BaseCallKwargs(),
+                    user_message_param=None,
+                    start_time=0,
+                    end_time=0,
+                )
+
+            return inner
+
+        return wrapper
+
+    with patch(
+        "mirascope.llm._call._get_provider_call",
+        return_value=dummy_async_provider_call,
+    ):
+        # Create a function with the openai provider
+        @call(provider="openai", model="gpt-4o-mini")
+        async def dummy_async_function():
+            pass  # pragma: no cover
+
+        # Call the function with a context that overrides the model
+        with context(provider="openai", model="gpt-4o"):
+            await dummy_async_function()
+
+        # Check that the context override was applied
+        assert captured_args["model"] == "gpt-4o", (
+            "Context model override was not applied in async function"
+        )
+
+
+@pytest.mark.asyncio
+async def test_context_in_async_function_with_gather():
+    """Test that context is properly applied in async functions when using asyncio.gather."""
+    # Create a mock provider call that captures the effective call args for each call
+    captured_args_list = []
+
+    def dummy_async_provider_call(
+        model,
+        stream,
+        tools,
+        response_model,
+        output_parser,
+        json_mode,
+        call_params,
+        client,
+    ):
+        def wrapper(fn):
+            async def inner(*args, **kwargs):
+                # Store the args that were passed to the provider call
+                captured_args = {
+                    "model": model,
+                    "stream": stream,
+                    "tools": tools,
+                    "response_model": response_model,
+                    "output_parser": output_parser,
+                    "json_mode": json_mode,
+                    "call_params": call_params,
+                    "client": client,
+                }
+                captured_args_list.append(captured_args)
+
+                return ConcreteResponse(
+                    metadata=Metadata(),
+                    response={},
+                    tool_types=None,
+                    prompt_template=None,
+                    fn_args={},
+                    dynamic_config={},
+                    messages=[],
+                    call_params=DummyCallParams(),
+                    call_kwargs=BaseCallKwargs(),
+                    user_message_param=None,
+                    start_time=0,
+                    end_time=0,
+                )
+
+            return inner
+
+        return wrapper
+
+    with patch(
+        "mirascope.llm._call._get_provider_call",
+        return_value=dummy_async_provider_call,
+    ):
+        # Create a function with the openai provider
+        @call(provider="openai", model="gpt-4o-mini")
+        async def dummy_async_function():
+            pass  # pragma: no cover
+
+        # Create futures first, then await them together
+        import asyncio
+
+        # Create the first future with default provider/model
+        future1 = dummy_async_function()
+
+        # Create the second future with a different context
+        with context(provider="anthropic", model="claude-3-5-sonnet"):
+            future2 = dummy_async_function()
+
+        # Await both futures together
+        await asyncio.gather(future1, future2)
+
+        # Check that we have two captured args
+        assert len(captured_args_list) == 2
+
+        # The first should use the original model
+        assert captured_args_list[0]["model"] == "gpt-4o-mini"
+
+        # The second should use the context-overridden model
+        assert captured_args_list[1]["model"] == "claude-3-5-sonnet", (
+            "Context model override was not applied when using asyncio.gather"
+        )
diff --git a/tests/llm/test_context.py b/tests/llm/test_context.py
index 6eb5b5ad9..d9fe050d6 100644
--- a/tests/llm/test_context.py
+++ b/tests/llm/test_context.py
@@ -338,3 +338,50 @@ def test_context_function_with_client_and_params():
         client=mock_client,
         call_params=call_params,
     )
+
+
+def test_apply_context_overrides_to_call_args_with_explicit_context():
+    """Test applying context overrides with an explicitly provided context."""
+    # Original call args
+    call_args: CallArgs = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "stream": False,
+        "tools": None,
+        "response_model": None,
+        "output_parser": None,
+        "json_mode": False,
+        "client": None,
+        "call_params": None,
+    }
+
+    # Create an explicit context
+    explicit_context = LLMContext(
+        provider="anthropic",
+        model="claude-3-5-sonnet",
+        stream=True,
+    )
+
+    # Apply the explicit context
+    result = apply_context_overrides_to_call_args(
+        call_args, context_override=explicit_context
+    )
+
+    # Check that the explicit context was applied
+    assert result["provider"] == "anthropic"
+    assert result["model"] == "claude-3-5-sonnet"
+    assert result["stream"] is True
+
+    # Test that it takes precedence over the current context
+    with _context(
+        provider="openai",
+        model="gpt-4o",
+        stream=False,
+    ):
+        # The explicit context should be used, not the current context
+        result = apply_context_overrides_to_call_args(
+            call_args, context_override=explicit_context
+        )
+        assert result["provider"] == "anthropic"
+        assert result["model"] == "claude-3-5-sonnet"
+        assert result["stream"] is True
diff --git a/tests/llm/test_override.py b/tests/llm/test_override.py
index aad563fa9..af74d2622 100644
--- a/tests/llm/test_override.py
+++ b/tests/llm/test_override.py
@@ -121,3 +121,59 @@ def test_override_with_client():
     mock_context.assert_called_once()
     _, kwargs = mock_context.call_args
     assert kwargs["client"] is new_client
+
+
+@pytest.mark.asyncio
+async def test_override_with_async_gather():
+    """Test that override works correctly with asyncio.gather."""
+    import asyncio
+
+    # Mock the provider_agnostic_call
+    async def mock_async_fn():
+        return MagicMock()
+
+    # Mock the _context to capture the context parameters
+    with patch("mirascope.llm._override._context") as mock_context:
+        # Set up the mock to record calls and return a context manager
+        context_manager = MagicMock()
+        context_manager.__enter__ = MagicMock(return_value=None)
+        context_manager.__exit__ = MagicMock(return_value=None)
+        mock_context.return_value = context_manager
+
+        # Create two overridden functions with different providers/models
+        openai_fn = override(
+            provider_agnostic_call=mock_async_fn,
+            provider="openai",
+            model="gpt-4o-mini",
+            call_params=None,
+        )
+
+        anthropic_fn = override(
+            provider_agnostic_call=mock_async_fn,
+            provider="anthropic",
+            model="claude-3-5-sonnet",
+            call_params=None,
+        )
+
+        # Create futures for both functions
+        openai_future = openai_fn()
+        anthropic_future = anthropic_fn()
+
+        # Await both futures together
+        await asyncio.gather(openai_future, anthropic_future)
+
+        # Check that both contexts were created with the correct parameters
+        assert mock_context.call_count == 2
+
+        # Extract the call arguments
+        call_args_list = mock_context.call_args_list
+
+        # First call should be for openai
+        _, kwargs1 = call_args_list[0]
+        assert kwargs1["provider"] == "openai"
+        assert kwargs1["model"] == "gpt-4o-mini"
+
+        # Second call should be for anthropic
+        _, kwargs2 = call_args_list[1]
+        assert kwargs2["provider"] == "anthropic"
+        assert kwargs2["model"] == "claude-3-5-sonnet"
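
For reviewers, a minimal usage sketch of the behavior this patch enables, mirroring `test_context_in_async_function_with_gather` above: the context active when the awaitable is created is the one applied, even if the call is awaited later via `asyncio.gather`. This sketch assumes `call` and `context` are available on the public `mirascope.llm` namespace (the tests import them from the private `_call`/`_context` modules) and that provider API keys are configured; the prompt function and model names are illustrative.

```python
import asyncio

from mirascope import llm


# Hypothetical prompt function; any async @llm.call function behaves the same.
@llm.call(provider="openai", model="gpt-4o-mini")
async def answer(question: str) -> str:
    return f"Answer this question: {question}"


async def main() -> None:
    # The context is captured when the awaitable is created...
    default_future = answer("What is the capital of France?")
    with llm.context(provider="anthropic", model="claude-3-5-sonnet"):
        overridden_future = answer("What is the capital of France?")

    # ...so gathering outside the `with` block still applies the override
    # to the second call only.
    default_res, overridden_res = await asyncio.gather(
        default_future, overridden_future
    )
    print(default_res.content)
    print(overridden_res.content)


asyncio.run(main())  # requires OPENAI_API_KEY and ANTHROPIC_API_KEY
```

Before this patch, `inner_async` read the thread-local context only when awaited, so an override from an already-exited `with llm.context(...)` block was lost; capturing `get_current_context()` at call time in `wrapper_with_context` is what makes the gather case deterministic.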