diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index b2488543b02..c0494272b34 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -70,6 +70,9 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel): tags: Union[Dict[str, str], None] = None """AWS resource tags added to each table""" + session_based_auth: bool = False + """AWS session based client authentication""" + class DynamoDBOnlineStore(OnlineStore): """ @@ -104,10 +107,14 @@ def update( online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) dynamodb_client = self._get_dynamodb_client( - online_config.region, online_config.endpoint_url + online_config.region, + online_config.endpoint_url, + online_config.session_based_auth, ) dynamodb_resource = self._get_dynamodb_resource( - online_config.region, online_config.endpoint_url + online_config.region, + online_config.endpoint_url, + online_config.session_based_auth, ) # Add Tags attribute to creation request only if configured to prevent # TagResource permission issues, even with an empty Tags array. @@ -166,7 +173,9 @@ def teardown( online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) dynamodb_resource = self._get_dynamodb_resource( - online_config.region, online_config.endpoint_url + online_config.region, + online_config.endpoint_url, + online_config.session_based_auth, ) for table in tables: @@ -201,7 +210,9 @@ def online_write_batch( online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) dynamodb_resource = self._get_dynamodb_resource( - online_config.region, online_config.endpoint_url + online_config.region, + online_config.endpoint_url, + online_config.session_based_auth, ) table_instance = dynamodb_resource.Table( @@ -228,7 +239,9 @@ def online_read( assert isinstance(online_config, DynamoDBOnlineStoreConfig) dynamodb_resource = self._get_dynamodb_resource( - online_config.region, online_config.endpoint_url + online_config.region, + online_config.endpoint_url, + online_config.session_based_auth, ) table_instance = dynamodb_resource.Table( _get_table_name(online_config, config, table) @@ -323,15 +336,27 @@ def _get_aioboto_session(self): def _get_aiodynamodb_client(self, region: str): return self._get_aioboto_session().create_client("dynamodb", region_name=region) - def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None): + def _get_dynamodb_client( + self, + region: str, + endpoint_url: Optional[str] = None, + session_based_auth: Optional[bool] = False, + ): if self._dynamodb_client is None: - self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url) + self._dynamodb_client = _initialize_dynamodb_client( + region, endpoint_url, session_based_auth + ) return self._dynamodb_client - def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None): + def _get_dynamodb_resource( + self, + region: str, + endpoint_url: Optional[str] = None, + session_based_auth: Optional[bool] = False, + ): if self._dynamodb_resource is None: self._dynamodb_resource = _initialize_dynamodb_resource( - region, endpoint_url + region, endpoint_url, session_based_auth ) return self._dynamodb_resource @@ -443,17 +468,38 @@ def _to_client_batch_get_payload(online_config, table_name, batch): } -def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None): - return boto3.client( - "dynamodb", - region_name=region, - endpoint_url=endpoint_url, - config=Config(user_agent=get_user_agent()), - ) +def _initialize_dynamodb_client( + region: str, + endpoint_url: Optional[str] = None, + session_based_auth: Optional[bool] = False, +): + if session_based_auth: + return boto3.Session().client( + "dynamodb", + region_name=region, + endpoint_url=endpoint_url, + config=Config(user_agent=get_user_agent()), + ) + else: + return boto3.client( + "dynamodb", + region_name=region, + endpoint_url=endpoint_url, + config=Config(user_agent=get_user_agent()), + ) -def _initialize_dynamodb_resource(region: str, endpoint_url: Optional[str] = None): - return boto3.resource("dynamodb", region_name=region, endpoint_url=endpoint_url) +def _initialize_dynamodb_resource( + region: str, + endpoint_url: Optional[str] = None, + session_based_auth: Optional[bool] = False, +): + if session_based_auth: + return boto3.Session().resource( + "dynamodb", region_name=region, endpoint_url=endpoint_url + ) + else: + return boto3.resource("dynamodb", region_name=region, endpoint_url=endpoint_url) # TODO(achals): This form of user-facing templating is experimental.