Merge pull request #909 from teamdandelion/fix-llm-context-async
fix: make llm.context work for async functions
willbakst authored Mar 10, 2025
2 parents 0611264 + 17a5f52 commit 38c4890
Showing 5 changed files with 316 additions and 41 deletions.
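
Background for reviewers: before this change, an `@llm.call`-decorated async function resolved `llm.context(...)` overrides only when its coroutine actually ran on the event loop, so overrides present at creation time could be gone (or wrong) by the time it executed. A minimal stdlib-only sketch of the failure mode and of the call-time capture that fixes it (simplified stand-ins, not Mirascope's actual implementation):

import asyncio
import threading

_local = threading.local()  # stand-in for the library's thread-local context

class context:
    """Simplified stand-in for llm.context(...), not the real implementation."""

    def __init__(self, **overrides):
        self.overrides = overrides

    def __enter__(self):
        self.prev = getattr(_local, "ctx", None)
        _local.ctx = self.overrides
        return self

    def __exit__(self, *exc):
        _local.ctx = self.prev

async def reads_context_at_run_time():
    # Pre-fix behavior: the context is looked up when the coroutine RUNS.
    return getattr(_local, "ctx", None)

def captures_context_at_call_time():
    # Post-fix behavior: a sync wrapper snapshots the context immediately
    # and binds it into the coroutine it returns.
    captured = getattr(_local, "ctx", None)

    async def inner():
        return captured

    return inner()

async def main():
    with context(model="gpt-4o"):
        late = reads_context_at_run_time()       # coroutine created, not run yet
        early = captures_context_at_call_time()  # snapshot taken here
    print(await late)   # None -- the override was already unset
    print(await early)  # {'model': 'gpt-4o'} -- the override survived

asyncio.run(main())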
87 changes: 51 additions & 36 deletions mirascope/llm/_call.py
@@ -20,7 +20,11 @@
 from ..core.base._utils import fn_is_async
 from ..core.base.stream_config import StreamConfig
 from ..core.base.types import LocalProvider, Provider
-from ._context import CallArgs, apply_context_overrides_to_call_args
+from ._context import (
+    CallArgs,
+    apply_context_overrides_to_call_args,
+    get_current_context,
+)
 from ._protocols import (
     AsyncLLMFunctionDecorator,
     CallDecorator,
@@ -235,55 +239,66 @@ def wrapper(
         | Awaitable[(_ResponseModelT | CallResponse)],
     ]:
         if fn_is_async(fn):
+
+            # Create a wrapper function that captures the current context when called
             @wraps(fn)
-            async def inner_async(
+            def wrapper_with_context(
                 *args: _P.args, **kwargs: _P.kwargs
-            ) -> (
+            ) -> Awaitable[
                 CallResponse
                 | Stream
                 | _ResponseModelT
                 | _ParsedOutputT
                 | (_ResponseModelT | CallResponse)
-            ):
-                # Apply any context overrides to the original call args
-                effective_call_args = apply_context_overrides_to_call_args(
-                    original_call_args
-                )
+            ]:
+                # Capture the context at call time
+                current_context = get_current_context()
+
+                # Define an async function that uses the captured context
+                async def context_bound_inner_async() -> (
+                    CallResponse
+                    | Stream
+                    | _ResponseModelT
+                    | _ParsedOutputT
+                    | (_ResponseModelT | CallResponse)
+                ):
+                    # Apply any context overrides to the original call args
+                    effective_call_args = apply_context_overrides_to_call_args(
+                        original_call_args, context_override=current_context
+                    )
 
-                # Get the appropriate provider call function with the possibly overridden provider
-                effective_provider = effective_call_args["provider"]
-                effective_client = effective_call_args["client"]
+                    # Get the appropriate provider call function with the possibly overridden provider
+                    effective_provider = effective_call_args["provider"]
+                    effective_client = effective_call_args["client"]
 
-                if effective_provider in get_args(LocalProvider):
-                    provider_call, effective_client = _get_local_provider_call(
-                        cast(LocalProvider, effective_provider),
-                        effective_client,
-                        True,
-                    )
-                    effective_call_args["client"] = effective_client
-                else:
-                    provider_call = _get_provider_call(
-                        cast(Provider, effective_provider)
-                    )
+                    if effective_provider in get_args(LocalProvider):
+                        provider_call, effective_client = _get_local_provider_call(
+                            cast(LocalProvider, effective_provider),
+                            effective_client,
+                            True,
+                        )
+                        effective_call_args["client"] = effective_client
+                    else:
+                        provider_call = _get_provider_call(
+                            cast(Provider, effective_provider)
+                        )
 
-                # Use the provider-specific call function with overridden args
-                call_kwargs = dict(effective_call_args)
-                del call_kwargs[
-                    "provider"
-                ]  # Remove provider as it's not a parameter to provider_call
+                    # Use the provider-specific call function with overridden args
+                    call_kwargs = dict(effective_call_args)
+                    del call_kwargs["provider"]  # Not a parameter to provider_call
 
-                # Get decorated function using provider_call
-                decorated = provider_call(**call_kwargs)(fn)
+                    # Get decorated function using provider_call
+                    decorated = provider_call(**call_kwargs)(fn)
 
-                # Call the decorated function and wrap the result
-                result = await decorated(*args, **kwargs)
-                return _wrap_result(result)
+                    # Call the decorated function and wrap the result
+                    result = await decorated(*args, **kwargs)
+                    return _wrap_result(result)
 
-            inner_async._original_call_args = original_call_args  # pyright: ignore [reportAttributeAccessIssue]
-            inner_async._original_fn = fn  # pyright: ignore [reportAttributeAccessIssue]
+                return context_bound_inner_async()
 
-            return inner_async
+            wrapper_with_context._original_call_args = original_call_args  # pyright: ignore [reportAttributeAccessIssue]
+            wrapper_with_context._original_fn = fn  # pyright: ignore [reportAttributeAccessIssue]
+
+            return wrapper_with_context  # pyright: ignore [reportReturnType]
         else:
 
             @wraps(fn)
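
The shape of the fix above: `wrapper_with_context` is now a regular (sync) function that snapshots the ambient context the moment it is called, then returns a coroutine bound to that snapshot — which is why its return annotation becomes `Awaitable[...]` and the final `return` needs a `reportReturnType` ignore. A stripped-down sketch of the pattern (hypothetical names, stdlib only, not the library's real internals):

import asyncio

current_overrides: dict = {}  # stand-in for the thread-local context state

def make_wrapper(fn):
    # A sync wrapper that returns a coroutine is awaitable at the call site
    # exactly like the original async function, but it can snapshot
    # call-time state before the event loop ever runs the body.
    def wrapper(*args, **kwargs):
        captured = dict(current_overrides)  # snapshot at call time

        async def bound():
            return await fn(*args, captured=captured, **kwargs)

        return bound()  # a coroutine object; the caller awaits it as usual

    return wrapper

async def greet(name: str, captured: dict) -> str:
    return f"hello {name}, overrides={captured}"

async def main():
    wrapped = make_wrapper(greet)
    current_overrides["model"] = "gpt-4o"
    coro = wrapped("ada")      # snapshot happens here
    current_overrides.clear()  # later changes no longer affect the pending call
    print(await coro)          # -> hello ada, overrides={'model': 'gpt-4o'}

asyncio.run(main())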
13 changes: 8 additions & 5 deletions mirascope/llm/_context.py
@@ -114,10 +114,10 @@ def _context(
     client: Any | None = None,  # noqa: ANN401
     call_params: CommonCallParams | Any | None = None,  # noqa: ANN401
 ) -> LLMContext:
-    """Context manager for synchronous LLM API calls.
+    """Context manager for LLM API calls.
 
     This is an internal method that allows both setting and structural overrides
-    for synchronous functions.
+    for LLM functions.
 
     Unfortunately we have not yet identified a way to properly type hint this because
     providing no structural overrides means the return type is that of the original
@@ -138,7 +138,7 @@ def _context(
         client: The client to use for the LLM API call.
         call_params: The call parameters for the LLM API call.
 
-    Yields:
+    Returns:
         The context object that can be used to apply the context to a function.
     """
     old_context: LLMContext | None = getattr(_current_context_local, "context", None)
@@ -171,16 +171,19 @@
     )
 
 
-def apply_context_overrides_to_call_args(call_args: CallArgs) -> CallArgs:
+def apply_context_overrides_to_call_args(
+    call_args: CallArgs, context_override: LLMContext | None = None
+) -> CallArgs:
     """Apply any active context overrides to the call arguments.
 
     Args:
         call_args: The original call arguments.
+        context_override: Optional explicit context to use instead of the current thread context.
 
     Returns:
         The call arguments with any context overrides applied.
     """
-    context = get_current_context()
+    context = context_override or get_current_context()
     if not context:
         return call_args
 
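
The companion change above threads the call-time snapshot through `apply_context_overrides_to_call_args` via the new `context_override` parameter, which takes precedence over the thread-local lookup. A minimal sketch of the precedence rule, with a hypothetical two-field context standing in for the real `LLMContext`:

from dataclasses import dataclass

@dataclass
class Ctx:
    """Hypothetical mini-context; the real LLMContext carries more fields."""
    provider: str | None = None
    model: str | None = None

_current: Ctx | None = None  # stand-in for the thread-local current context

def apply_overrides(call_args: dict, context_override: Ctx | None = None) -> dict:
    # Same precedence as the patched helper: an explicitly passed context
    # wins; otherwise fall back to the ambient context; otherwise no-op.
    ctx = context_override or _current
    if ctx is None:
        return call_args
    overrides = {k: v for k, v in vars(ctx).items() if v is not None}
    return {**call_args, **overrides}

args = {"provider": "openai", "model": "gpt-4o-mini"}
assert apply_overrides(args)["model"] == "gpt-4o-mini"  # no context active
assert apply_overrides(args, Ctx(model="gpt-4o"))["model"] == "gpt-4o"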
154 changes: 154 additions & 0 deletions tests/llm/test_call.py
@@ -22,6 +22,7 @@
     _wrap_result,
     call,
 )
+from mirascope.llm._context import context
 from mirascope.llm.call_response import CallResponse
 from mirascope.llm.stream import Stream
 
@@ -384,3 +385,156 @@ async def dummy_local_async_function(): ...
     res = await dummy_local_async_function()
     assert isinstance(res, CallResponse)
     assert res.finish_reasons == ["stop"]
+
+
+@pytest.mark.asyncio
+async def test_context_in_async_function():
+    """Test that context is properly applied in async functions."""
+    # Create a mock provider call that captures the effective call args
+    captured_args = {}
+
+    def dummy_async_provider_call(
+        model,
+        stream,
+        tools,
+        response_model,
+        output_parser,
+        json_mode,
+        call_params,
+        client,
+    ):
+        def wrapper(fn):
+            async def inner(*args, **kwargs):
+                # Store the args that were passed to the provider call
+                nonlocal captured_args
+                captured_args = {
+                    "model": model,
+                    "stream": stream,
+                    "tools": tools,
+                    "response_model": response_model,
+                    "output_parser": output_parser,
+                    "json_mode": json_mode,
+                    "call_params": call_params,
+                    "client": client,
+                }
+
+                return ConcreteResponse(
+                    metadata=Metadata(),
+                    response={},
+                    tool_types=None,
+                    prompt_template=None,
+                    fn_args={},
+                    dynamic_config={},
+                    messages=[],
+                    call_params=DummyCallParams(),
+                    call_kwargs=BaseCallKwargs(),
+                    user_message_param=None,
+                    start_time=0,
+                    end_time=0,
+                )
+
+            return inner
+
+        return wrapper
+
+    with patch(
+        "mirascope.llm._call._get_provider_call",
+        return_value=dummy_async_provider_call,
+    ):
+        # Create a function with the openai provider
+        @call(provider="openai", model="gpt-4o-mini")
+        async def dummy_async_function():
+            pass  # pragma: no cover
+
+        # Call the function with a context that overrides the model
+        with context(provider="openai", model="gpt-4o"):
+            await dummy_async_function()
+
+        # Check that the context override was applied
+        assert captured_args["model"] == "gpt-4o", (
+            "Context model override was not applied in async function"
+        )
+
+
+@pytest.mark.asyncio
+async def test_context_in_async_function_with_gather():
+    """Test that context is properly applied in async functions when using asyncio.gather."""
+    # Create a mock provider call that captures the effective call args for each call
+    captured_args_list = []
+
+    def dummy_async_provider_call(
+        model,
+        stream,
+        tools,
+        response_model,
+        output_parser,
+        json_mode,
+        call_params,
+        client,
+    ):
+        def wrapper(fn):
+            async def inner(*args, **kwargs):
+                # Store the args that were passed to the provider call
+                captured_args = {
+                    "model": model,
+                    "stream": stream,
+                    "tools": tools,
+                    "response_model": response_model,
+                    "output_parser": output_parser,
+                    "json_mode": json_mode,
+                    "call_params": call_params,
+                    "client": client,
+                }
+                captured_args_list.append(captured_args)
+
+                return ConcreteResponse(
+                    metadata=Metadata(),
+                    response={},
+                    tool_types=None,
+                    prompt_template=None,
+                    fn_args={},
+                    dynamic_config={},
+                    messages=[],
+                    call_params=DummyCallParams(),
+                    call_kwargs=BaseCallKwargs(),
+                    user_message_param=None,
+                    start_time=0,
+                    end_time=0,
+                )
+
+            return inner
+
+        return wrapper
+
+    with patch(
+        "mirascope.llm._call._get_provider_call",
+        return_value=dummy_async_provider_call,
+    ):
+        # Create a function with the openai provider
+        @call(provider="openai", model="gpt-4o-mini")
+        async def dummy_async_function():
+            pass  # pragma: no cover
+
+        # Create futures first, then await them together
+        import asyncio
+
+        # Create the first future with default provider/model
+        future1 = dummy_async_function()
+
+        # Create the second future with a different context
+        with context(provider="anthropic", model="claude-3-5-sonnet"):
+            future2 = dummy_async_function()
+
+        # Await both futures together
+        await asyncio.gather(future1, future2)
+
+        # Check that we have two captured args
+        assert len(captured_args_list) == 2
+
+        # The first should use the original model
+        assert captured_args_list[0]["model"] == "gpt-4o-mini"
+
+        # The second should use the context-overridden model
+        assert captured_args_list[1]["model"] == "claude-3-5-sonnet", (
+            "Context model override was not applied when using asyncio.gather"
+        )
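
For reference, the `asyncio.gather` test above mirrors user-facing usage along these lines — a sketch assuming Mirascope's public `llm.call` / `llm.context` API and real provider credentials; `answer` is a hypothetical prompt function:

import asyncio

from mirascope import llm

@llm.call(provider="openai", model="gpt-4o-mini")
async def answer(question: str) -> str:
    return f"Answer concisely: {question}"

async def main():
    # Each override is captured when its coroutine is created, so awaiting
    # later via gather still uses the right provider/model per call.
    default_call = answer("What is 2 + 2?")
    with llm.context(provider="anthropic", model="claude-3-5-sonnet"):
        override_call = answer("What is 3 + 3?")
    responses = await asyncio.gather(default_call, override_call)
    for response in responses:
        print(response.content)

asyncio.run(main())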
47 changes: 47 additions & 0 deletions tests/llm/test_context.py
@@ -338,3 +338,50 @@ def test_context_function_with_client_and_params():
         client=mock_client,
         call_params=call_params,
     )
+
+
+def test_apply_context_overrides_to_call_args_with_explicit_context():
+    """Test applying context overrides with an explicitly provided context."""
+    # Original call args
+    call_args: CallArgs = {
+        "provider": "openai",
+        "model": "gpt-4",
+        "stream": False,
+        "tools": None,
+        "response_model": None,
+        "output_parser": None,
+        "json_mode": False,
+        "client": None,
+        "call_params": None,
+    }
+
+    # Create an explicit context
+    explicit_context = LLMContext(
+        provider="anthropic",
+        model="claude-3-5-sonnet",
+        stream=True,
+    )
+
+    # Apply the explicit context
+    result = apply_context_overrides_to_call_args(
+        call_args, context_override=explicit_context
+    )
+
+    # Check that the explicit context was applied
+    assert result["provider"] == "anthropic"
+    assert result["model"] == "claude-3-5-sonnet"
+    assert result["stream"] is True
+
+    # Test that it takes precedence over the current context
+    with _context(
+        provider="openai",
+        model="gpt-4o",
+        stream=False,
+    ):
+        # The explicit context should be used, not the current context
+        result = apply_context_overrides_to_call_args(
+            call_args, context_override=explicit_context
+        )
+        assert result["provider"] == "anthropic"
+        assert result["model"] == "claude-3-5-sonnet"
+        assert result["stream"] is True
