Skip to content

Commit 1d2f2d3

Browse files
authoredDec 13, 2021
Organize Step Function classes in Amazon provider (#20158)
* 20139 - organize aws step_function
1 parent 2213635 commit 1d2f2d3

11 files changed

+315
-252
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
19+
import json
20+
from typing import Optional, Union
21+
22+
from airflow.exceptions import AirflowException
23+
from airflow.models import BaseOperator
24+
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
25+
26+
27+
class StepFunctionStartExecutionOperator(BaseOperator):
28+
"""
29+
An Operator that begins execution of an Step Function State Machine
30+
31+
Additional arguments may be specified and are passed down to the underlying BaseOperator.
32+
33+
.. seealso::
34+
:class:`~airflow.models.BaseOperator`
35+
36+
:param state_machine_arn: ARN of the Step Function State Machine
37+
:type state_machine_arn: str
38+
:param name: The name of the execution.
39+
:type name: Optional[str]
40+
:param state_machine_input: JSON data input to pass to the State Machine
41+
:type state_machine_input: Union[Dict[str, any], str, None]
42+
:param aws_conn_id: aws connection to uses
43+
:type aws_conn_id: str
44+
:param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
45+
:type do_xcom_push: bool
46+
"""
47+
48+
template_fields = ['state_machine_arn', 'name', 'input']
49+
template_ext = ()
50+
ui_color = '#f9c915'
51+
52+
def __init__(
53+
self,
54+
*,
55+
state_machine_arn: str,
56+
name: Optional[str] = None,
57+
state_machine_input: Union[dict, str, None] = None,
58+
aws_conn_id: str = 'aws_default',
59+
region_name: Optional[str] = None,
60+
**kwargs,
61+
):
62+
super().__init__(**kwargs)
63+
self.state_machine_arn = state_machine_arn
64+
self.name = name
65+
self.input = state_machine_input
66+
self.aws_conn_id = aws_conn_id
67+
self.region_name = region_name
68+
69+
def execute(self, context):
70+
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
71+
72+
execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
73+
74+
if execution_arn is None:
75+
raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}')
76+
77+
self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn)
78+
79+
return execution_arn
80+
81+
82+
class StepFunctionGetExecutionOutputOperator(BaseOperator):
83+
"""
84+
An Operator that begins execution of an Step Function State Machine
85+
86+
Additional arguments may be specified and are passed down to the underlying BaseOperator.
87+
88+
.. seealso::
89+
:class:`~airflow.models.BaseOperator`
90+
91+
:param execution_arn: ARN of the Step Function State Machine Execution
92+
:type execution_arn: str
93+
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
94+
:type aws_conn_id: str
95+
"""
96+
97+
template_fields = ['execution_arn']
98+
template_ext = ()
99+
ui_color = '#f9c915'
100+
101+
def __init__(
102+
self,
103+
*,
104+
execution_arn: str,
105+
aws_conn_id: str = 'aws_default',
106+
region_name: Optional[str] = None,
107+
**kwargs,
108+
):
109+
super().__init__(**kwargs)
110+
self.execution_arn = execution_arn
111+
self.aws_conn_id = aws_conn_id
112+
self.region_name = region_name
113+
114+
def execute(self, context):
115+
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
116+
117+
execution_status = hook.describe_execution(self.execution_arn)
118+
execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None
119+
120+
self.log.info('Got State Machine Execution output for %s', self.execution_arn)
121+
122+
return execution_output

‎airflow/providers/amazon/aws/operators/step_function_get_execution_output.py

+10-45
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,16 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import json
19-
from typing import Optional
18+
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.step_function`."""
2019

21-
from airflow.models import BaseOperator
22-
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
20+
import warnings
2321

22+
from airflow.providers.amazon.aws.operators.step_function import ( # noqa
23+
StepFunctionGetExecutionOutputOperator,
24+
)
2425

25-
class StepFunctionGetExecutionOutputOperator(BaseOperator):
26-
"""
27-
An Operator that begins execution of an Step Function State Machine
28-
29-
Additional arguments may be specified and are passed down to the underlying BaseOperator.
30-
31-
.. seealso::
32-
:class:`~airflow.models.BaseOperator`
33-
34-
:param execution_arn: ARN of the Step Function State Machine Execution
35-
:type execution_arn: str
36-
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
37-
:type aws_conn_id: str
38-
"""
39-
40-
template_fields = ['execution_arn']
41-
template_ext = ()
42-
ui_color = '#f9c915'
43-
44-
def __init__(
45-
self,
46-
*,
47-
execution_arn: str,
48-
aws_conn_id: str = 'aws_default',
49-
region_name: Optional[str] = None,
50-
**kwargs,
51-
):
52-
super().__init__(**kwargs)
53-
self.execution_arn = execution_arn
54-
self.aws_conn_id = aws_conn_id
55-
self.region_name = region_name
56-
57-
def execute(self, context):
58-
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
59-
60-
execution_status = hook.describe_execution(self.execution_arn)
61-
execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None
62-
63-
self.log.info('Got State Machine Execution output for %s', self.execution_arn)
64-
65-
return execution_output
26+
warnings.warn(
27+
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.",
28+
DeprecationWarning,
29+
stacklevel=2,
30+
)

‎airflow/providers/amazon/aws/operators/step_function_start_execution.py

+8-57
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Optional, Union
18+
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.step_function`."""
1919

20-
from airflow.exceptions import AirflowException
21-
from airflow.models import BaseOperator
22-
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
20+
import warnings
2321

22+
from airflow.providers.amazon.aws.operators.step_function import StepFunctionStartExecutionOperator # noqa
2423

25-
class StepFunctionStartExecutionOperator(BaseOperator):
26-
"""
27-
An Operator that begins execution of an Step Function State Machine
28-
29-
Additional arguments may be specified and are passed down to the underlying BaseOperator.
30-
31-
.. seealso::
32-
:class:`~airflow.models.BaseOperator`
33-
34-
:param state_machine_arn: ARN of the Step Function State Machine
35-
:type state_machine_arn: str
36-
:param name: The name of the execution.
37-
:type name: Optional[str]
38-
:param state_machine_input: JSON data input to pass to the State Machine
39-
:type state_machine_input: Union[Dict[str, any], str, None]
40-
:param aws_conn_id: aws connection to uses
41-
:type aws_conn_id: str
42-
:param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
43-
:type do_xcom_push: bool
44-
"""
45-
46-
template_fields = ['state_machine_arn', 'name', 'input']
47-
template_ext = ()
48-
ui_color = '#f9c915'
49-
50-
def __init__(
51-
self,
52-
*,
53-
state_machine_arn: str,
54-
name: Optional[str] = None,
55-
state_machine_input: Union[dict, str, None] = None,
56-
aws_conn_id: str = 'aws_default',
57-
region_name: Optional[str] = None,
58-
**kwargs,
59-
):
60-
super().__init__(**kwargs)
61-
self.state_machine_arn = state_machine_arn
62-
self.name = name
63-
self.input = state_machine_input
64-
self.aws_conn_id = aws_conn_id
65-
self.region_name = region_name
66-
67-
def execute(self, context):
68-
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
69-
70-
execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
71-
72-
if execution_arn is None:
73-
raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}')
74-
75-
self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn)
76-
77-
return execution_arn
24+
warnings.warn(
25+
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.",
26+
DeprecationWarning,
27+
stacklevel=2,
28+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import json
19+
from typing import Optional
20+
21+
from airflow.exceptions import AirflowException
22+
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
23+
from airflow.sensors.base import BaseSensorOperator
24+
25+
26+
class StepFunctionExecutionSensor(BaseSensorOperator):
27+
"""
28+
Asks for the state of the Step Function State Machine Execution until it
29+
reaches a failure state or success state.
30+
If it fails, failing the task.
31+
32+
On successful completion of the Execution the Sensor will do an XCom Push
33+
of the State Machine's output to `output`
34+
35+
:param execution_arn: execution_arn to check the state of
36+
:type execution_arn: str
37+
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
38+
:type aws_conn_id: str
39+
"""
40+
41+
INTERMEDIATE_STATES = ('RUNNING',)
42+
FAILURE_STATES = (
43+
'FAILED',
44+
'TIMED_OUT',
45+
'ABORTED',
46+
)
47+
SUCCESS_STATES = ('SUCCEEDED',)
48+
49+
template_fields = ['execution_arn']
50+
template_ext = ()
51+
ui_color = '#66c3ff'
52+
53+
def __init__(
54+
self,
55+
*,
56+
execution_arn: str,
57+
aws_conn_id: str = 'aws_default',
58+
region_name: Optional[str] = None,
59+
**kwargs,
60+
):
61+
super().__init__(**kwargs)
62+
self.execution_arn = execution_arn
63+
self.aws_conn_id = aws_conn_id
64+
self.region_name = region_name
65+
self.hook: Optional[StepFunctionHook] = None
66+
67+
def poke(self, context):
68+
execution_status = self.get_hook().describe_execution(self.execution_arn)
69+
state = execution_status['status']
70+
output = json.loads(execution_status['output']) if 'output' in execution_status else None
71+
72+
if state in self.FAILURE_STATES:
73+
raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}')
74+
75+
if state in self.INTERMEDIATE_STATES:
76+
return False
77+
78+
self.log.info('Doing xcom_push of output')
79+
self.xcom_push(context, 'output', output)
80+
return True
81+
82+
def get_hook(self) -> StepFunctionHook:
83+
"""Create and return a StepFunctionHook"""
84+
if self.hook:
85+
return self.hook
86+
87+
self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
88+
return self.hook

0 commit comments

Comments
 (0)