Amazon EMR on Amazon EKS #17178
@@ -0,0 +1,142 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional

import botocore.exceptions

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

class EmrContainersHook(AwsBaseHook):
    """
    Interact with Amazon EMR on EKS.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    hook_name = 'Elastic MapReduce Containers'

    def __init__(self, *args, **kwargs) -> None:
        # Use the boto3 "emr-containers" client for all API calls made by this hook.
        kwargs["client_type"] = "emr-containers"
        super().__init__(*args, **kwargs)

    def handle_aws_client_error(self, error_response: dict) -> None:
        """Log errors returned by the AWS client.

        :param error_response: The ``Error`` dictionary of an AWS service exception
        :type error_response: dict
        """
        self.log.error("%s: %s", self.hook_name, error_response.get("Code"))
        self.log.error("%s: %s", self.hook_name, error_response.get("Message"))

    def get_job_by_id(self, job_id: str, cluster_id: str) -> Optional[dict]:
        """Get job details by job id and virtual cluster id.

        If the job is found, returns a response describing the job.

        :param job_id: The ID of the job run request
        :type job_id: str
        :param cluster_id: The ID of the virtual cluster to which the job run was submitted
        :type cluster_id: str
        :return: A dictionary describing the job run
        :rtype: dict
        """
        try:
            return self.get_conn().describe_job_run(id=job_id, virtualClusterId=cluster_id)
        except botocore.exceptions.ClientError as err:
            error_response = err.response.get("Error", {})
            self.handle_aws_client_error(error_response)
            # Only request-level errors are re-raised; other client errors are logged and None is returned.
            if error_response.get("Code") in ("ValidationException", "InternalServerException"):
                raise AirflowException(error_response.get("Message"))

    def start_job(
        self,
        cluster_id: str,
        execution_role_arn: str,
        emr_release_label: str,
        job_driver: Dict[str, Any],
        **kwargs: Any,
    ) -> Dict[str, str]:
        """Start a Spark job on Amazon EMR on EKS.

        :param cluster_id: The ID of the virtual cluster to which the job run is submitted
        :type cluster_id: str
        :param execution_role_arn: The execution role ARN for the job run.
        :type execution_role_arn: str
        :param emr_release_label: The Amazon EMR release version to use for the job run.
        :type emr_release_label: str
        :param job_driver: The job driver for the job run.
        :type job_driver: dict
        :param configuration_overrides: The configuration overrides for the job run.
        :type configuration_overrides: dict
        :param tags: The tags assigned to the job run.
        :type tags: dict
        :param name: The name of the job run.
        :type name: str
        :param client_token: The client idempotency token of the job run request. Auto-populated if not provided.
        :type client_token: str
        :return: A response with job run details
        :rtype: dict
        """
        params = {
            "virtualClusterId": cluster_id,
            "executionRoleArn": execution_role_arn,
            "releaseLabel": emr_release_label,
            "jobDriver": job_driver,
        }
        # Map optional keyword arguments onto the corresponding StartJobRun API fields,
        # skipping any that were not supplied.
        optional_params = (
            ("configurationOverrides", "configuration_overrides"),
            ("name", "name"),
            ("clientToken", "client_token"),
            ("tags", "tags"),
        )
        for aws_var_name, airflow_var_name in optional_params:
            if kwargs.get(airflow_var_name):
                params[aws_var_name] = kwargs[airflow_var_name]

        try:
            return self.get_conn().start_job_run(**params)
        except botocore.exceptions.ClientError as err:
            error_response = err.response.get("Error", {})
            self.handle_aws_client_error(error_response)
            raise AirflowException(error_response.get("Message"))

    def terminate_job_by_id(self, job_id: str, cluster_id: str) -> Optional[dict]:
        """Terminate a job by job id and virtual cluster id.

        If the job is found, returns a response with the job id and virtual cluster id.

        :param job_id: The ID of the job run to terminate
        :type job_id: str
        :param cluster_id: The ID of the virtual cluster to which the job run was submitted
        :type cluster_id: str
        :return: A dictionary with the job run id and virtual cluster id
        :rtype: dict
        """
        try:
            return self.get_conn().cancel_job_run(id=job_id, virtualClusterId=cluster_id)
        except botocore.exceptions.ClientError as err:
            error_response = err.response.get("Error", {})
            self.handle_aws_client_error(error_response)
            raise AirflowException(error_response.get("Message"))
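For orientation only (not part of this change), here is a minimal sketch of how the hook could be used directly to submit a job and then check its state. The virtual cluster id, IAM role ARN, S3 entry point and connection id are hypothetical placeholders.

```python
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook

# Assumes an 'aws_default' Airflow connection is configured; all ids below are placeholders.
hook = EmrContainersHook(aws_conn_id='aws_default')

response = hook.start_job(
    cluster_id='abcd1234efgh5678',  # hypothetical virtual cluster id
    execution_role_arn='arn:aws:iam::123456789012:role/emr-eks-job-role',
    emr_release_label='emr-6.2.0-latest',
    job_driver={
        'sparkSubmitJobDriver': {
            'entryPoint': 's3://my-bucket/scripts/etl_job.py',
            'sparkSubmitParameters': '--conf spark.executor.instances=2',
        }
    },
    name='example-job',
)

# StartJobRun returns the new job run id, which can then be used to poll the job state.
job = hook.get_job_by_id(job_id=response['id'], cluster_id='abcd1234efgh5678')
print(job['jobRun']['state'])  # e.g. PENDING, RUNNING, COMPLETED, ...
```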
@@ -0,0 +1,53 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook

class EmrContainersCancelJobOperator(BaseOperator):
    """Operator to cancel a job run.

    Cancels a job run. A job run is a unit of work, such as a Spark jar, PySpark script,
    or SparkSQL query, that you submit to Amazon EMR on EKS.

    :param job_id: The ID of the job run to cancel.
    :type job_id: str
    :param cluster_id: The ID of the virtual cluster in which the job run will be cancelled
    :type cluster_id: str
    :param aws_conn_id: The Airflow connection used for AWS credentials
    :type aws_conn_id: str
    """

    ui_color = '#f9c915'

    def __init__(self, *, job_id: str, cluster_id: str, aws_conn_id: str = 'aws_default', **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.cluster_id = cluster_id
        self.aws_conn_id = aws_conn_id

    def execute(self, context: Dict[str, Any]) -> str:
        emr_containers = EmrContainersHook(aws_conn_id=self.aws_conn_id)

        self.log.info('Cancelling EMR Containers job %s', self.job_id)
        response = emr_containers.terminate_job_by_id(job_id=self.job_id, cluster_id=self.cluster_id)
        self.log.info('EMR Containers job %s has been cancelled', response["id"])
        # The cancelled job run id is pushed to XCom as the operator's return value.
        return response["id"]
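For reference (not part of this change), a minimal sketch of wiring the cancel operator into a DAG. The import path, job run id and virtual cluster id are hypothetical placeholders.

```python
from datetime import datetime

from airflow import DAG
# Hypothetical import path; the operator is the class defined in this PR.
from airflow.providers.amazon.aws.operators.emr_containers import EmrContainersCancelJobOperator

with DAG(dag_id='emr_eks_cancel_example', start_date=datetime(2021, 7, 1), schedule_interval=None) as dag:
    cancel_job = EmrContainersCancelJobOperator(
        task_id='cancel_job',
        job_id='0000000abcdefg',        # hypothetical job run id
        cluster_id='abcd1234efgh5678',  # hypothetical virtual cluster id
    )
```

Note that the operator does not declare ``template_fields``, so ``job_id`` and ``cluster_id`` cannot be passed as Jinja/XCom templates; they have to be concrete values when the DAG is parsed.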
@@ -0,0 +1,61 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook

class EmrContainersGetJobStateOperator(BaseOperator):
    """Operator to get a job run's state.

    A job run is a unit of work, such as a Spark jar, PySpark script, or SparkSQL query,
    that you submit to Amazon EMR on EKS.

    A job run is in one of the PENDING, SUBMITTED, RUNNING, FAILED, CANCELLED,
    CANCEL_PENDING or COMPLETED states.

    :param job_id: The ID of the job run request
    :type job_id: str
    :param cluster_id: The ID of the virtual cluster to which the job run was submitted
    :type cluster_id: str
    :param aws_conn_id: The Airflow connection used for AWS credentials
    :type aws_conn_id: str
    """

    ui_color = '#f9c915'

    def __init__(
        self, *, job_id: str, cluster_id: str, aws_conn_id: str = 'aws_default', **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.job_id = job_id
        self.cluster_id = cluster_id
        self.aws_conn_id = aws_conn_id

    def execute(self, context: Dict[str, Any]) -> str:
        """Check a job's state in EMR on EKS.

        :return: The job run state
        :rtype: str
        """
        emr_containers = EmrContainersHook(aws_conn_id=self.aws_conn_id)

        self.log.info('Checking job %s state in cluster %s', self.job_id, self.cluster_id)
        response = emr_containers.get_job_by_id(self.job_id, self.cluster_id)
        # The state is pushed to XCom as the operator's return value.
        return response["jobRun"]["state"]
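For reference (not part of this change), a minimal sketch showing that the returned state lands in XCom and can be read by a downstream task. The import path and ids are hypothetical placeholders.

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
# Hypothetical import path; the operator is the class defined in this PR.
from airflow.providers.amazon.aws.operators.emr_containers import EmrContainersGetJobStateOperator

with DAG(dag_id='emr_eks_job_state_example', start_date=datetime(2021, 7, 1), schedule_interval=None) as dag:
    get_state = EmrContainersGetJobStateOperator(
        task_id='get_state',
        job_id='0000000abcdefg',        # hypothetical job run id
        cluster_id='abcd1234efgh5678',  # hypothetical virtual cluster id
    )

    def report_state(ti):
        # The operator's return value (the job state) is available via XCom.
        state = ti.xcom_pull(task_ids='get_state')
        print(f'EMR on EKS job state: {state}')

    report = PythonOperator(task_id='report_state', python_callable=report_state)
    get_state >> report
```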
@@ -0,0 +1,95 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Optional

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook

class EmrContainersStartJobOperator(BaseOperator):
    """Operator to start a job.

    A job run is a unit of work, such as a Spark jar, PySpark script, or SparkSQL query,
    that you submit to Amazon EMR on EKS.

    :param cluster_id: The ID of the virtual cluster to which the job run is submitted
    :type cluster_id: str
    :param execution_role_arn: The execution role ARN for the job run.
    :type execution_role_arn: str
    :param emr_release_label: The Amazon EMR release version to use for the job run.
    :type emr_release_label: str
    :param job_driver: The job driver for the job run.
    :type job_driver: dict
    :param configuration_overrides: The configuration overrides for the job run.
    :type configuration_overrides: dict
    :param tags: The tags assigned to the job run.
    :type tags: dict
    :param name: The name of the job run.
    :type name: str
    :param client_token: The client idempotency token of the job run request. Auto-populated if not provided.
    :type client_token: str
    :param aws_conn_id: The Airflow connection used for AWS credentials
    :type aws_conn_id: str
    """

    ui_color = '#f9c915'

    def __init__(
        self,
        *,
        cluster_id: str,
        execution_role_arn: str,
        emr_release_label: str,
        job_driver: Dict[str, Any],
        configuration_overrides: Optional[dict] = None,
        tags: Optional[dict] = None,
        name: Optional[str] = None,
        client_token: Optional[str] = None,
        aws_conn_id: str = 'aws_default',
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Collect the job parameters so they can be passed straight to EmrContainersHook.start_job().
        self.start_job_params = dict(
            cluster_id=cluster_id,
            execution_role_arn=execution_role_arn,
            emr_release_label=emr_release_label,
            job_driver=job_driver,
            configuration_overrides=configuration_overrides,
            tags=tags,
            name=name,
            client_token=client_token,
Review comment: I don't see you defining […]

Reply: It is being used in line 89; ``response = emr_containers.start_job(**self.start_job_params)``

Review comment: Right, but it's not initialized anywhere. The default is […]
Reply: Thanks @dacort for this.
        )
        self.aws_conn_id = aws_conn_id

    def execute(self, context: Dict[str, Any]) -> str:
        """Start a job in EMR on EKS.

        :return: The job run id
        :rtype: str
        """
        emr_containers = EmrContainersHook(aws_conn_id=self.aws_conn_id)

        self.log.info('Starting job in EMR Containers')
        response = emr_containers.start_job(**self.start_job_params)
        self.log.info(
            'Job %s has been started in EMR Containers in cluster %s',
            response["id"],
            response["virtualClusterId"],
        )
        # The new job run id is pushed to XCom as the operator's return value.
        return response["id"]
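For reference (not part of this change), a minimal DAG sketch that starts a job with this operator. The import path, virtual cluster id, role ARN, S3 locations and log group are hypothetical placeholders.

```python
from datetime import datetime

from airflow import DAG
# Hypothetical import path; the operator is the class defined in this PR.
from airflow.providers.amazon.aws.operators.emr_containers import EmrContainersStartJobOperator

with DAG(dag_id='emr_eks_start_job_example', start_date=datetime(2021, 7, 1), schedule_interval=None) as dag:
    start_job = EmrContainersStartJobOperator(
        task_id='start_job',
        cluster_id='abcd1234efgh5678',  # hypothetical virtual cluster id
        execution_role_arn='arn:aws:iam::123456789012:role/emr-eks-job-role',
        emr_release_label='emr-6.2.0-latest',
        job_driver={
            'sparkSubmitJobDriver': {
                'entryPoint': 's3://my-bucket/scripts/etl_job.py',
                'entryPointArguments': ['--output', 's3://my-bucket/output/'],
                'sparkSubmitParameters': '--conf spark.executor.instances=2',
            }
        },
        configuration_overrides={
            'monitoringConfiguration': {
                'cloudWatchMonitoringConfiguration': {
                    'logGroupName': '/emr-containers/jobs',  # hypothetical log group
                }
            }
        },
        name='example-job',
    )
```

The job run id returned from ``execute`` is pushed to XCom, so downstream tasks could pick it up once the operators grow templating support.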
Review comment: I'm curious what the use case is for an operator that just returns the job state? Apologies if this is a silly question - still getting up to speed on typical Airflow patterns.

Reply: Now that you've asked, I don't think this has a straightforward use case as an operator.
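If the underlying need is to wait until the job finishes, the more typical Airflow pattern would be a sensor that polls the job state rather than an operator that reports it once. Purely as a hypothetical sketch built on the hook in this PR (not part of this change), it could look roughly like this:

```python
from typing import Any, Dict

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook
from airflow.sensors.base import BaseSensorOperator


class EmrContainersJobSensor(BaseSensorOperator):
    """Hypothetical sensor that waits until an EMR on EKS job run reaches a terminal state."""

    def __init__(self, *, job_id: str, cluster_id: str, aws_conn_id: str = 'aws_default', **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.job_id = job_id
        self.cluster_id = cluster_id
        self.aws_conn_id = aws_conn_id

    def poke(self, context: Dict[str, Any]) -> bool:
        hook = EmrContainersHook(aws_conn_id=self.aws_conn_id)
        response = hook.get_job_by_id(self.job_id, self.cluster_id)
        state = response["jobRun"]["state"]
        self.log.info('Job %s is in state %s', self.job_id, state)
        if state in ('FAILED', 'CANCELLED'):
            raise AirflowException(f'EMR on EKS job {self.job_id} ended in state {state}')
        # Returning True stops poking; the sensor succeeds once the job completes.
        return state == 'COMPLETED'
```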