diff --git a/mirascope/core/bedrock/_utils/_setup_call.py b/mirascope/core/bedrock/_utils/_setup_call.py index 9d1441ced..b80099cf6 100644 --- a/mirascope/core/bedrock/_utils/_setup_call.py +++ b/mirascope/core/bedrock/_utils/_setup_call.py @@ -2,6 +2,7 @@ from __future__ import annotations +import os from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator from functools import wraps from typing import Any, ParamSpec, cast, overload @@ -223,12 +224,23 @@ def setup_call( call_kwargs |= cast(BedrockCallKwargs, {"modelId": model, "messages": messages}) + env_vars = {} + if access_key_id := os.getenv("AWS_ACCESS_KEY_ID"): + env_vars["aws_access_key_id"] = access_key_id + if secret_access_key := os.getenv("AWS_SECRET_ACCESS_KEY"): + env_vars["aws_secret_access_key"] = secret_access_key + if session_token := os.getenv("AWS_SESSION_TOKEN"): + env_vars["aws_session_token"] = session_token + if region_name := os.getenv("AWS_REGION_NAME"): + env_vars["region_name"] = region_name + if profile_name := os.getenv("AWS_PROFILE"): + env_vars["profile_name"] = profile_name if client is None: if fn_is_async(fn): - session = get_session() + session = get_session(env_vars=env_vars) _client = _AsyncBedrockRuntimeWrappedClient(session, model) else: - session = Session() + session = Session(**env_vars) _client = session.client("bedrock-runtime") else: _client = client diff --git a/pyproject.toml b/pyproject.toml index 041c41c8e..340cc2f5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mirascope" -version = "1.21.2" +version = "1.21.3" description = "LLM abstractions that aren't obstructions" readme = "README.md" license = { file = "LICENSE" } diff --git a/tests/core/bedrock/_utils/test_setup_call.py b/tests/core/bedrock/_utils/test_setup_call.py index 7a16c0484..4a42eb0ce 100644 --- a/tests/core/bedrock/_utils/test_setup_call.py +++ b/tests/core/bedrock/_utils/test_setup_call.py @@ -493,3 +493,80 @@ async def fake_stream(): mock_client.converse_stream.assert_called_once_with(param2="value2") mock_session.create_client.assert_called_once_with("bedrock-runtime") + + +@patch("mirascope.core.bedrock._utils._setup_call.Session") +@patch("mirascope.core.bedrock._utils._setup_call.get_session") +@patch("mirascope.core.bedrock._utils._setup_call._utils", new_callable=MagicMock) +def test_setup_call_env_vars( + mock_utils: MagicMock, + mock_get_session: MagicMock, + mock_session_class: MagicMock, + mock_base_setup_call: MagicMock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that environment variables are properly passed to Session and get_session.""" + mock_utils.setup_call = mock_base_setup_call + mock_base_setup_call.return_value[1] = [ + {"role": "user", "content": [{"text": "user test"}]}, + ] + mock_base_setup_call.return_value[3] = {} + + # Set environment variables + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "test_session_token") + monkeypatch.setenv("AWS_REGION_NAME", "us-west-2") + monkeypatch.setenv("AWS_PROFILE", "test_profile") + + # Test sync client creation with env vars + setup_call( + model="anthropic.claude-v2", + client=None, + fn=MagicMock(), + fn_args={}, + dynamic_config=None, + tools=None, + json_mode=False, + call_params={}, + response_model=None, + stream=False, + ) + + # Verify Session was called with expected env vars + mock_session_class.assert_called_once_with( + aws_access_key_id="test_access_key", + aws_secret_access_key="test_secret_key", + aws_session_token="test_session_token", + region_name="us-west-2", + profile_name="test_profile", + ) + + # Test async client creation with env vars + async def async_fn(): ... + + mock_session_class.reset_mock() + + setup_call( + model="anthropic.claude-v2", + client=None, + fn=async_fn, + fn_args={}, + dynamic_config=None, + tools=None, + json_mode=False, + call_params={}, + response_model=None, + stream=False, + ) + + # Verify get_session was called with expected env vars + mock_get_session.assert_called_with( + env_vars={ + "aws_access_key_id": "test_access_key", + "aws_secret_access_key": "test_secret_key", + "aws_session_token": "test_session_token", + "region_name": "us-west-2", + "profile_name": "test_profile", + } + ) diff --git a/uv.lock b/uv.lock index 79671200f..554421aed 100644 --- a/uv.lock +++ b/uv.lock @@ -3168,7 +3168,7 @@ wheels = [ [[package]] name = "mirascope" -version = "1.21.1" +version = "1.21.2" source = { editable = "." } dependencies = [ { name = "docstring-parser" },