Skip to content

Commit 1d2f2d3

Browse files
authored Dec 13, 2021
Organize Step Function classes in Amazon provider (#20158)
* 20139 - organize aws step_function
1 parent 2213635 commit 1d2f2d3

11 files changed

+315
-252
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
19+
import json
20+
from typing import Optional, Union
21+
22+
from airflow.exceptions import AirflowException
23+
from airflow.models import BaseOperator
24+
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
25+
26+
27+
class StepFunctionStartExecutionOperator(BaseOperator):
    """
    An Operator that begins execution of a Step Function State Machine.

    Additional arguments may be specified and are passed down to the underlying BaseOperator.

    .. seealso::
        :class:`~airflow.models.BaseOperator`

    :param state_machine_arn: ARN of the Step Function State Machine
    :type state_machine_arn: str
    :param name: The name of the execution.
    :type name: Optional[str]
    :param state_machine_input: JSON data input to pass to the State Machine
    :type state_machine_input: Union[Dict[str, any], str, None]
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    :param region_name: aws region to use; when None the connection's region applies
    :type region_name: Optional[str]
    :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
    :type do_xcom_push: bool
    """

    template_fields = ['state_machine_arn', 'name', 'input']
    template_ext = ()
    ui_color = '#f9c915'

    def __init__(
        self,
        *,
        state_machine_arn: str,
        name: Optional[str] = None,
        state_machine_input: Union[dict, str, None] = None,
        aws_conn_id: str = 'aws_default',
        region_name: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.state_machine_arn = state_machine_arn
        self.name = name
        self.input = state_machine_input
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name

    def execute(self, context):
        """Start the execution and return its ARN, failing loudly when AWS returns none."""
        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
        execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)

        # A missing ARN means the start call did not succeed.
        if execution_arn is None:
            raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}')

        self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn)
        return execution_arn
80+
81+
82+
class StepFunctionGetExecutionOutputOperator(BaseOperator):
    """
    An Operator that returns the output of a Step Function State Machine Execution.

    The execution is described via :class:`StepFunctionHook`; when the response
    contains an ``output`` field it is JSON-decoded and returned, otherwise
    ``None`` is returned.

    Additional arguments may be specified and are passed down to the underlying BaseOperator.

    .. seealso::
        :class:`~airflow.models.BaseOperator`

    :param execution_arn: ARN of the Step Function State Machine Execution
    :type execution_arn: str
    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
    :type aws_conn_id: str
    :param region_name: aws region to use; when None the connection's region applies
    :type region_name: Optional[str]
    """

    template_fields = ['execution_arn']
    template_ext = ()
    ui_color = '#f9c915'

    def __init__(
        self,
        *,
        execution_arn: str,
        aws_conn_id: str = 'aws_default',
        region_name: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.execution_arn = execution_arn
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name

    def execute(self, context):
        """Describe the execution and return its JSON-decoded output (or None)."""
        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

        execution_status = hook.describe_execution(self.execution_arn)
        # 'output' is only present once the execution has produced a result.
        execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None

        self.log.info('Got State Machine Execution output for %s', self.execution_arn)

        return execution_output

‎airflow/providers/amazon/aws/operators/step_function_get_execution_output.py

+10-45
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,16 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import json
19-
from typing import Optional
18+
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.step_function`."""
2019

21-
from airflow.models import BaseOperator
22-
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
20+
import warnings
2321

22+
from airflow.providers.amazon.aws.operators.step_function import ( # noqa
23+
StepFunctionGetExecutionOutputOperator,
24+
)
2425

25-
class StepFunctionGetExecutionOutputOperator(BaseOperator):
26-
"""
27-
An Operator that begins execution of an Step Function State Machine
28-
29-
Additional arguments may be specified and are passed down to the underlying BaseOperator.
30-
31-
.. seealso::
32-
:class:`~airflow.models.BaseOperator`
33-
34-
:param execution_arn: ARN of the Step Function State Machine Execution
35-
:type execution_arn: str
36-
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
37-
:type aws_conn_id: str
38-
"""
39-
40-
template_fields = ['execution_arn']
41-
template_ext = ()
42-
ui_color = '#f9c915'
43-
44-
def __init__(
45-
self,
46-
*,
47-
execution_arn: str,
48-
aws_conn_id: str = 'aws_default',
49-
region_name: Optional[str] = None,
50-
**kwargs,
51-
):
52-
super().__init__(**kwargs)
53-
self.execution_arn = execution_arn
54-
self.aws_conn_id = aws_conn_id
55-
self.region_name = region_name
56-
57-
def execute(self, context):
58-
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
59-
60-
execution_status = hook.describe_execution(self.execution_arn)
61-
execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None
62-
63-
self.log.info('Got State Machine Execution output for %s', self.execution_arn)
64-
65-
return execution_output
26+
warnings.warn(
27+
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.",
28+
DeprecationWarning,
29+
stacklevel=2,
30+
)

‎airflow/providers/amazon/aws/operators/step_function_start_execution.py

+8-57
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Optional, Union
18+
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.step_function`."""
1919

20-
from airflow.exceptions import AirflowException
21-
from airflow.models import BaseOperator
22-
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
20+
import warnings
2321

22+
from airflow.providers.amazon.aws.operators.step_function import StepFunctionStartExecutionOperator # noqa
2423

25-
class StepFunctionStartExecutionOperator(BaseOperator):
26-
"""
27-
An Operator that begins execution of an Step Function State Machine
28-
29-
Additional arguments may be specified and are passed down to the underlying BaseOperator.
30-
31-
.. seealso::
32-
:class:`~airflow.models.BaseOperator`
33-
34-
:param state_machine_arn: ARN of the Step Function State Machine
35-
:type state_machine_arn: str
36-
:param name: The name of the execution.
37-
:type name: Optional[str]
38-
:param state_machine_input: JSON data input to pass to the State Machine
39-
:type state_machine_input: Union[Dict[str, any], str, None]
40-
:param aws_conn_id: aws connection to uses
41-
:type aws_conn_id: str
42-
:param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
43-
:type do_xcom_push: bool
44-
"""
45-
46-
template_fields = ['state_machine_arn', 'name', 'input']
47-
template_ext = ()
48-
ui_color = '#f9c915'
49-
50-
def __init__(
51-
self,
52-
*,
53-
state_machine_arn: str,
54-
name: Optional[str] = None,
55-
state_machine_input: Union[dict, str, None] = None,
56-
aws_conn_id: str = 'aws_default',
57-
region_name: Optional[str] = None,
58-
**kwargs,
59-
):
60-
super().__init__(**kwargs)
61-
self.state_machine_arn = state_machine_arn
62-
self.name = name
63-
self.input = state_machine_input
64-
self.aws_conn_id = aws_conn_id
65-
self.region_name = region_name
66-
67-
def execute(self, context):
68-
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
69-
70-
execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
71-
72-
if execution_arn is None:
73-
raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}')
74-
75-
self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn)
76-
77-
return execution_arn
24+
warnings.warn(
25+
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.",
26+
DeprecationWarning,
27+
stacklevel=2,
28+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import json
19+
from typing import Optional
20+
21+
from airflow.exceptions import AirflowException
22+
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
23+
from airflow.sensors.base import BaseSensorOperator
24+
25+
26+
class StepFunctionExecutionSensor(BaseSensorOperator):
    """
    Poll the state of a Step Function State Machine Execution until it reaches
    a failure state or a success state; a failure state fails the task.

    On successful completion of the Execution the Sensor will do an XCom Push
    of the State Machine's output to `output`

    :param execution_arn: execution_arn to check the state of
    :type execution_arn: str
    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
    :type aws_conn_id: str
    :param region_name: aws region to use; when None the connection's region applies
    :type region_name: Optional[str]
    """

    INTERMEDIATE_STATES = ('RUNNING',)
    FAILURE_STATES = ('FAILED', 'TIMED_OUT', 'ABORTED')
    SUCCESS_STATES = ('SUCCEEDED',)

    template_fields = ['execution_arn']
    template_ext = ()
    ui_color = '#66c3ff'

    def __init__(
        self,
        *,
        execution_arn: str,
        aws_conn_id: str = 'aws_default',
        region_name: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.execution_arn = execution_arn
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        # Lazily-created hook, cached by get_hook().
        self.hook: Optional[StepFunctionHook] = None

    def poke(self, context):
        """Return False while running; push output and return True on success; raise on failure."""
        status_response = self.get_hook().describe_execution(self.execution_arn)
        state = status_response['status']
        # 'output' only appears once the execution has produced a result.
        output = json.loads(status_response['output']) if 'output' in status_response else None

        if state in self.FAILURE_STATES:
            raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}')

        if state in self.INTERMEDIATE_STATES:
            return False

        self.log.info('Doing xcom_push of output')
        self.xcom_push(context, 'output', output)
        return True

    def get_hook(self) -> StepFunctionHook:
        """Create and return a StepFunctionHook"""
        if self.hook is None:
            self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
        return self.hook

‎airflow/providers/amazon/aws/sensors/step_function_execution.py

+8-68
Original file line numberDiff line numberDiff line change
@@ -15,74 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import json
19-
from typing import Optional
18+
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.step_function`."""
2019

21-
from airflow.exceptions import AirflowException
22-
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
23-
from airflow.sensors.base import BaseSensorOperator
20+
import warnings
2421

22+
from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor # noqa
2523

26-
class StepFunctionExecutionSensor(BaseSensorOperator):
27-
"""
28-
Asks for the state of the Step Function State Machine Execution until it
29-
reaches a failure state or success state.
30-
If it fails, failing the task.
31-
32-
On successful completion of the Execution the Sensor will do an XCom Push
33-
of the State Machine's output to `output`
34-
35-
:param execution_arn: execution_arn to check the state of
36-
:type execution_arn: str
37-
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
38-
:type aws_conn_id: str
39-
"""
40-
41-
INTERMEDIATE_STATES = ('RUNNING',)
42-
FAILURE_STATES = (
43-
'FAILED',
44-
'TIMED_OUT',
45-
'ABORTED',
46-
)
47-
SUCCESS_STATES = ('SUCCEEDED',)
48-
49-
template_fields = ['execution_arn']
50-
template_ext = ()
51-
ui_color = '#66c3ff'
52-
53-
def __init__(
54-
self,
55-
*,
56-
execution_arn: str,
57-
aws_conn_id: str = 'aws_default',
58-
region_name: Optional[str] = None,
59-
**kwargs,
60-
):
61-
super().__init__(**kwargs)
62-
self.execution_arn = execution_arn
63-
self.aws_conn_id = aws_conn_id
64-
self.region_name = region_name
65-
self.hook: Optional[StepFunctionHook] = None
66-
67-
def poke(self, context):
68-
execution_status = self.get_hook().describe_execution(self.execution_arn)
69-
state = execution_status['status']
70-
output = json.loads(execution_status['output']) if 'output' in execution_status else None
71-
72-
if state in self.FAILURE_STATES:
73-
raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}')
74-
75-
if state in self.INTERMEDIATE_STATES:
76-
return False
77-
78-
self.log.info('Doing xcom_push of output')
79-
self.xcom_push(context, 'output', output)
80-
return True
81-
82-
def get_hook(self) -> StepFunctionHook:
83-
"""Create and return a StepFunctionHook"""
84-
if self.hook:
85-
return self.hook
86-
87-
self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
88-
return self.hook
24+
warnings.warn(
25+
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.step_function`.",
26+
DeprecationWarning,
27+
stacklevel=2,
28+
)

‎airflow/providers/amazon/provider.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ operators:
244244
python-modules:
245245
- airflow.providers.amazon.aws.operators.step_function_get_execution_output
246246
- airflow.providers.amazon.aws.operators.step_function_start_execution
247+
- airflow.providers.amazon.aws.operators.step_function
247248
- integration-name: Amazon Redshift
248249
python-modules:
249250
- airflow.providers.amazon.aws.operators.redshift
@@ -305,6 +306,7 @@ sensors:
305306
- integration-name: AWS Step Functions
306307
python-modules:
307308
- airflow.providers.amazon.aws.sensors.step_function_execution
309+
- airflow.providers.amazon.aws.sensors.step_function
308310

309311
hooks:
310312
- integration-name: Amazon Athena

‎dev/provider_packages/prepare_provider_packages.py

+2
Original file line numberDiff line numberDiff line change
@@ -2137,6 +2137,8 @@ def summarise_total_vs_bad_and_warnings(total: int, bad: int, warns: List[warnin
21372137
"This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.",
21382138
'numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header,'
21392139
' got 216 from PyObject',
2140+
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.step_function`.",
2141+
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.",
21402142
'This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.',
21412143
'This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.ec2`.',
21422144
}

‎tests/deprecated_classes.py

+14
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,16 @@
15181518
'airflow.providers.amazon.aws.operators.s3_file_transform.S3FileTransformOperator',
15191519
'airflow.operators.s3_file_transform_operator.S3FileTransformOperator',
15201520
),
1521+
(
1522+
'airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator',
1523+
'airflow.providers.amazon.aws.operators.step_function_start_execution'
1524+
'.StepFunctionStartExecutionOperator',
1525+
),
1526+
(
1527+
'airflow.providers.amazon.aws.operators.step_function.StepFunctionGetExecutionOutputOperator',
1528+
'airflow.providers.amazon.aws.operators.step_function_get_execution_output'
1529+
'.StepFunctionGetExecutionOutputOperator',
1530+
),
15211531
(
15221532
'airflow.providers.amazon.aws.sensors.s3_key.S3KeySensor',
15231533
'airflow.sensors.s3_key_sensor.S3KeySensor',
@@ -1526,6 +1536,10 @@
15261536
'airflow.providers.amazon.aws.sensors.s3_prefix.S3PrefixSensor',
15271537
'airflow.sensors.s3_prefix_sensor.S3PrefixSensor',
15281538
),
1539+
(
1540+
'airflow.providers.amazon.aws.sensors.step_function.StepFunctionExecutionSensor',
1541+
'airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionExecutionSensor',
1542+
),
15291543
(
15301544
'airflow.sensors.bash.BashSensor',
15311545
'airflow.contrib.sensors.bash_sensor.BashSensor',

‎tests/providers/amazon/aws/operators/test_step_function_start_execution.py ‎tests/providers/amazon/aws/operators/test_step_function.py

+57-9
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,80 @@
1515
# KIND, either express or implied. See the License for the
1616
# specific language governing permissions and limitations
1717
# under the License.
18-
#
1918

2019
import unittest
2120
from unittest import mock
2221
from unittest.mock import MagicMock
2322

24-
from airflow.providers.amazon.aws.operators.step_function_start_execution import (
23+
from airflow.providers.amazon.aws.operators.step_function import (
24+
StepFunctionGetExecutionOutputOperator,
2525
StepFunctionStartExecutionOperator,
2626
)
2727

28-
TASK_ID = 'step_function_start_execution_task'
28+
EXECUTION_ARN = (
29+
'arn:aws:states:us-east-1:123456789012:execution:'
30+
'pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934'
31+
)
32+
AWS_CONN_ID = 'aws_non_default'
33+
REGION_NAME = 'us-west-2'
2934
STATE_MACHINE_ARN = 'arn:aws:states:us-east-1:000000000000:stateMachine:pseudo-state-machine'
3035
NAME = 'NAME'
3136
INPUT = '{}'
32-
AWS_CONN_ID = 'aws_non_default'
33-
REGION_NAME = 'us-west-2'
37+
38+
39+
class TestStepFunctionGetExecutionOutputOperator(unittest.TestCase):
40+
TASK_ID = 'step_function_get_execution_output'
41+
42+
def setUp(self):
43+
self.mock_context = MagicMock()
44+
45+
def test_init(self):
46+
# Given / When
47+
operator = StepFunctionGetExecutionOutputOperator(
48+
task_id=self.TASK_ID,
49+
execution_arn=EXECUTION_ARN,
50+
aws_conn_id=AWS_CONN_ID,
51+
region_name=REGION_NAME,
52+
)
53+
54+
# Then
55+
assert self.TASK_ID == operator.task_id
56+
assert EXECUTION_ARN == operator.execution_arn
57+
assert AWS_CONN_ID == operator.aws_conn_id
58+
assert REGION_NAME == operator.region_name
59+
60+
@mock.patch('airflow.providers.amazon.aws.operators.step_function.StepFunctionHook')
61+
def test_execute(self, mock_hook):
62+
# Given
63+
hook_response = {'output': '{}'}
64+
65+
hook_instance = mock_hook.return_value
66+
hook_instance.describe_execution.return_value = hook_response
67+
68+
operator = StepFunctionGetExecutionOutputOperator(
69+
task_id=self.TASK_ID,
70+
execution_arn=EXECUTION_ARN,
71+
aws_conn_id=AWS_CONN_ID,
72+
region_name=REGION_NAME,
73+
)
74+
75+
# When
76+
result = operator.execute(self.mock_context)
77+
78+
# Then
79+
assert {} == result
3480

3581

3682
class TestStepFunctionStartExecutionOperator(unittest.TestCase):
83+
TASK_ID = 'step_function_start_execution_task'
84+
3785
def setUp(self):
3886
self.mock_context = MagicMock()
3987

4088
def test_init(self):
4189
# Given / When
4290
operator = StepFunctionStartExecutionOperator(
43-
task_id=TASK_ID,
91+
task_id=self.TASK_ID,
4492
state_machine_arn=STATE_MACHINE_ARN,
4593
name=NAME,
4694
state_machine_input=INPUT,
@@ -49,14 +97,14 @@ def test_init(self):
4997
)
5098

5199
# Then
52-
assert TASK_ID == operator.task_id
100+
assert self.TASK_ID == operator.task_id
53101
assert STATE_MACHINE_ARN == operator.state_machine_arn
54102
assert NAME == operator.name
55103
assert INPUT == operator.input
56104
assert AWS_CONN_ID == operator.aws_conn_id
57105
assert REGION_NAME == operator.region_name
58106

59-
@mock.patch('airflow.providers.amazon.aws.operators.step_function_start_execution.StepFunctionHook')
107+
@mock.patch('airflow.providers.amazon.aws.operators.step_function.StepFunctionHook')
60108
def test_execute(self, mock_hook):
61109
# Given
62110
hook_response = (
@@ -68,7 +116,7 @@ def test_execute(self, mock_hook):
68116
hook_instance.start_execution.return_value = hook_response
69117

70118
operator = StepFunctionStartExecutionOperator(
71-
task_id=TASK_ID,
119+
task_id=self.TASK_ID,
72120
state_machine_arn=STATE_MACHINE_ARN,
73121
name=NAME,
74122
state_machine_input=INPUT,

‎tests/providers/amazon/aws/operators/test_step_function_get_execution_output.py

-69
This file was deleted.

‎tests/providers/amazon/aws/sensors/test_step_function_execution.py ‎tests/providers/amazon/aws/sensors/test_step_function.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from parameterized import parameterized
2525

2626
from airflow.exceptions import AirflowException
27-
from airflow.providers.amazon.aws.sensors.step_function_execution import StepFunctionExecutionSensor
27+
from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor
2828

2929
TASK_ID = 'step_function_execution_sensor'
3030
EXECUTION_ARN = (
@@ -50,7 +50,7 @@ def test_init(self):
5050
assert REGION_NAME == sensor.region_name
5151

5252
@parameterized.expand([('FAILED',), ('TIMED_OUT',), ('ABORTED',)])
53-
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
53+
@mock.patch('airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook')
5454
def test_exceptions(self, mock_status, mock_hook):
5555
hook_response = {'status': mock_status}
5656

@@ -64,7 +64,7 @@ def test_exceptions(self, mock_status, mock_hook):
6464
with pytest.raises(AirflowException):
6565
sensor.poke(self.mock_context)
6666

67-
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
67+
@mock.patch('airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook')
6868
def test_running(self, mock_hook):
6969
hook_response = {'status': 'RUNNING'}
7070

@@ -77,7 +77,7 @@ def test_running(self, mock_hook):
7777

7878
assert not sensor.poke(self.mock_context)
7979

80-
@mock.patch('airflow.providers.amazon.aws.sensors.step_function_execution.StepFunctionHook')
80+
@mock.patch('airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook')
8181
def test_succeeded(self, mock_hook):
8282
hook_response = {'status': 'SUCCEEDED'}
8383

0 commit comments

Comments
 (0)
Please sign in to comment.