Amazon EMR on Amazon EKS #17178

Closed · wants to merge 4 commits · changes from 3 commits shown
142 changes: 142 additions & 0 deletions airflow/providers/amazon/aws/hooks/emr_containers.py
@@ -0,0 +1,142 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional

import botocore.exceptions

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


class EmrContainersHook(AwsBaseHook):
"""
Interact with Amazon EMR on EKS.

Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.

.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

hook_name = 'Elastic MapReduce Containers'

def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "emr-containers"
super().__init__(*args, **kwargs)

def handle_aws_client_error(self, error_response: dict) -> None:
"""Logs errors from the client

:param error_response: A dictionary for AWS service exceptions
:type error_response: str
"""
self.log.error("%s: %s", self.hook_name, error_response.get("Code"))
self.log.error("%s: %s", self.hook_name, error_response.get("Message"))

def get_job_by_id(self, job_id: str, cluster_id: str) -> Optional[dict]:
"""Get job details by job id and virtual cluster id.

If the job is found, returns a response describing the job.

:param job_id: The ID of the job run request
:type job_id: str
:param cluster_id: The ID of the virtual cluster for which the job run is submitted
:type cluster_id: str
:return: A response describing the job run, or ``None`` on a non-fatal client error
:rtype: Optional[dict]
"""
try:
return self.get_conn().describe_job_run(id=job_id, virtualClusterId=cluster_id)
except botocore.exceptions.ClientError as err:
error_response = err.response.get("Error", {})
self.handle_aws_client_error(error_response)
if error_response.get("Code") in ("ValidationException", "InternalServerException"):
raise AirflowException(error_response.get("Message"))

def start_job(
self,
cluster_id: str,
execution_role_arn: str,
emr_release_label: str,
job_driver: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, str]:
"""Starts a spark job using EMR in EKS

:param cluster_id: The ID of the virtual cluster for which the job run is submitted
:type cluster_id: str
:param execution_role_arn: The execution role ARN for the job run.
:type execution_role_arn: str
:param emr_release_label: The Amazon EMR release version to use for the job run.
:type emr_release_label: str
:param job_driver: The job driver for the job run.
:type job_driver: dict
:param configuration_overrides: The configuration overrides for the job run.
:type configuration_overrides: dict
:param tags: The tags assigned to job runs
:type tags: dict
:param name: The name of the job run.
:type name: str
:param client_token: The client idempotency token of the job run request. Autopopulated if not provided
:type client_token: str
:return: A response with job run details
:rtype: dict

"""
params = {
"virtualClusterId": cluster_id,
"executionRoleArn": execution_role_arn,
"releaseLabel": emr_release_label,
"jobDriver": job_driver,
}
optional_params = (
("configurationOverrides", "configuration_overrides"),
("name", "name"),
("clientToken", "client_token"),
("tags", "tags"),
)
for aws_var_name, airflow_var_name in optional_params:
if kwargs.get(airflow_var_name):
params[aws_var_name] = kwargs[airflow_var_name]

try:
return self.get_conn().start_job_run(**params)
except botocore.exceptions.ClientError as err:
error_response = err.response.get("Error", {})
self.handle_aws_client_error(error_response)
raise AirflowException(error_response.get("Message"))

def terminate_job_by_id(self, job_id: str, cluster_id: str) -> Optional[dict]:
"""Terminates a job by job id and virtual cluster id.

If the job is found, returns a response with job id and cluster id.

:param job_id: The ID of the job run request
:type job_id: str
:param cluster_id: The ID of the virtual cluster for which the job run is submitted
:type cluster_id: str
:return: A response with the job run id and virtual cluster id
:rtype: dict
"""
try:
return self.get_conn().cancel_job_run(id=job_id, virtualClusterId=cluster_id)
except botocore.exceptions.ClientError as err:
error_response = err.response.get("Error", {})
self.handle_aws_client_error(error_response)
raise AirflowException(error_response.get("Message"))
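
For context, a minimal usage sketch of the hook above (not part of the PR): the sparkSubmitJobDriver layout follows the EMR on EKS StartJobRun API, and the virtual cluster ID, role ARN, release label, and S3 entry point are all placeholder values.

from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook

hook = EmrContainersHook(aws_conn_id="aws_default")
response = hook.start_job(
    cluster_id="abcd1234virtualcluster",  # placeholder virtual cluster ID
    execution_role_arn="arn:aws:iam::123456789012:role/emr-eks-job-role",  # placeholder
    emr_release_label="emr-6.2.0-latest",  # placeholder release label
    job_driver={
        "sparkSubmitJobDriver": {
            "entryPoint": "s3://my-bucket/scripts/job.py",  # placeholder script path
            "sparkSubmitParameters": "--conf spark.executor.instances=2",
        }
    },
    name="sample-job-run",  # optional kwarg, mapped to the API's "name" field
)
print(response["id"], response["virtualClusterId"])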
53 changes: 53 additions & 0 deletions airflow/providers/amazon/aws/operators/emr_containers_cancel_job.py
@@ -0,0 +1,53 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook


class EmrContainersCancelJobOperator(BaseOperator):
"""Operator to a cancel a job run.

Cancels a job run. A job run is a unit of work,
such as a Spark jar, PySpark script, or SparkSQL query, that you submit to Amazon EMR on EKS.

:param job_id: The ID of the job run to cancel.
:type job_id: str
:param cluster_id: The ID of the virtual cluster for which the job run will be canceled
:type cluster_id: str
:param aws_conn_id: The Airflow connection to use for AWS credentials
:type aws_conn_id: str
"""

ui_color = '#f9c915'

def __init__(self, *, job_id: str, cluster_id: str, aws_conn_id: str = 'aws_default', **kwargs):
super().__init__(**kwargs)
self.job_id = job_id
self.cluster_id = cluster_id
self.aws_conn_id = aws_conn_id

def execute(self, context: Dict[str, Any]) -> str:
emr_containers = EmrContainersHook(aws_conn_id=self.aws_conn_id)

self.log.info('Cancelling EMR Containers job %s', self.job_id)
response = emr_containers.terminate_job_by_id(job_id=self.job_id, cluster_id=self.cluster_id)
self.log.info('EMR Containers job %s has been cancelled', response["id"])
return response["id"]
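
A hedged usage sketch of the operator above; the task_id, job run ID, and virtual cluster ID are placeholders rather than values from this PR.

from airflow.providers.amazon.aws.operators.emr_containers_cancel_job import EmrContainersCancelJobOperator

cancel_job = EmrContainersCancelJobOperator(
    task_id="cancel_spark_job",
    job_id="00000002aaaaaaaaaa",  # placeholder job run ID
    cluster_id="abcd1234virtualcluster",  # placeholder virtual cluster ID
)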
61 changes: 61 additions & 0 deletions airflow/providers/amazon/aws/operators/emr_containers_get_job_state.py
@@ -0,0 +1,61 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook


class EmrContainersGetJobStateOperator(BaseOperator):
Contributor: I'm curious what the use case is for an operator that just returns the job state? Apologies if this is a silly question - still getting up to speed on typical Airflow patterns.

Author (@wanderijames, Aug 4, 2021): Now that you asked me that question, I don't think this has a straightforward use case as an operator.


"""Operator to get a job status.

A job run is a unit of work, such as a Spark jar, PySpark script, or SparkSQL query,
that you submit to Amazon EMR on EKS.

A job run can be in the PENDING, SUBMITTED, RUNNING, FAILED, CANCELLED, CANCEL_PENDING, or COMPLETED state

:param job_id: The ID of the job run request
:type job_id: str
:param cluster_id: The ID of the virtual cluster for which the job run is submitted
:type cluster_id: str
:param aws_conn_id: The Airflow connection to use for AWS credentials
:type aws_conn_id: str
"""

ui_color = '#f9c915'

def __init__(
self, *, job_id: str, cluster_id: str, aws_conn_id: str = 'aws_default', **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.job_id = job_id
self.cluster_id = cluster_id
self.aws_conn_id = aws_conn_id

def execute(self, context: Dict[str, Any]) -> str:
"""Check a job state in EMR EKS

:return: A job state
:rtype: str
"""
emr_containers = EmrContainersHook(aws_conn_id=self.aws_conn_id)

self.log.info('Checking job %s state in cluster %s', self.job_id, self.cluster_id)
response = emr_containers.get_job_by_id(self.job_id, self.cluster_id)
return response["jobRun"]["state"]
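
Given the review thread above questioning a state-returning operator, the more idiomatic Airflow pattern would likely be a sensor that pokes until the job reaches a terminal state. A sketch only, not part of this PR; the terminal-state set is an assumption based on the EMR on EKS job run states.

from typing import Any, Dict

from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook
from airflow.sensors.base import BaseSensorOperator


class EmrContainersJobSensor(BaseSensorOperator):
    # Assumed terminal states; a real implementation would likely raise an
    # exception on FAILED/CANCELLED instead of simply succeeding.
    TERMINAL_STATES = ("COMPLETED", "FAILED", "CANCELLED")

    def __init__(self, *, job_id: str, cluster_id: str, aws_conn_id: str = "aws_default", **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.job_id = job_id
        self.cluster_id = cluster_id
        self.aws_conn_id = aws_conn_id

    def poke(self, context: Dict[str, Any]) -> bool:
        # Re-use the hook from this PR to poll the job run state.
        hook = EmrContainersHook(aws_conn_id=self.aws_conn_id)
        response = hook.get_job_by_id(self.job_id, self.cluster_id)
        return response["jobRun"]["state"] in self.TERMINAL_STATES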
95 changes: 95 additions & 0 deletions airflow/providers/amazon/aws/operators/emr_containers_start_job.py
@@ -0,0 +1,95 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Optional

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr_containers import EmrContainersHook


class EmrContainersStartJobOperator(BaseOperator):
"""Operator to start a job.

A job run is a unit of work, such as a Spark jar, PySpark script, or SparkSQL query,
that you submit to Amazon EMR on EKS.

:param cluster_id: The ID of the virtual cluster for which the job run is submitted
:type cluster_id: str
:param execution_role_arn: The execution role ARN for the job run.
:type execution_role_arn: str
:param emr_release_label: The Amazon EMR release version to use for the job run.
:type emr_release_label: str
:param job_driver: The job driver for the job run.
:type job_driver: dict
:param configuration_overrides: The configuration overrides for the job run.
:type configuration_overrides: dict
:param tags: The tags assigned to job runs
:type tags: dict
:param name: The name of the job run.
:type name: str
:param client_token: The client idempotency token of the job run request. Autopopulated if not provided
:type client_token: str
:param aws_conn_id: The Airflow connection to use for AWS credentials
:type aws_conn_id: str
"""

ui_color = '#f9c915'

def __init__(
self,
*,
cluster_id: str,
execution_role_arn: str,
emr_release_label: str,
job_driver: Dict[str, Any],
configuration_overrides: Optional[dict] = None,
tags: Optional[dict] = None,
name: Optional[str] = None,
client_token: Optional[str] = None,
aws_conn_id: str = 'aws_default',
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.start_job_params = dict(
cluster_id=cluster_id,
execution_role_arn=execution_role_arn,
emr_release_label=emr_release_label,
job_driver=job_driver,
configuration_overrides=configuration_overrides,
tags=tags,
name=name,
client_token=client_token,
Contributor: I don't see you defining client_token anywhere. Have you tested this in a live environment to validate that it works?

Author: It is being used in line 89:

response = emr_containers.start_job(**self.start_job_params)

Contributor: Right, but it's not initialized anywhere. The default is None, which means if the client doesn't pass in a client_token, a new job won't get created after the first one. I ran into a similar issue and had to generate a UUID as the default:

client_token or str(uuid4())

Author: Thanks @dacort for this.
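
A sketch of how the suggested default might be applied in the operator's __init__, assuming a per-task UUID is an acceptable idempotency token:

from uuid import uuid4

# Hypothetical adjustment inside EmrContainersStartJobOperator.__init__:
# fall back to a fresh UUID so each operator instance gets a distinct token.
client_token = client_token or str(uuid4())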

)
self.aws_conn_id = aws_conn_id

def execute(self, context: Dict[str, Any]) -> str:
"""Start a job in EMR EKS

:return: A job id
:rtype: str
"""
emr_containers = EmrContainersHook(aws_conn_id=self.aws_conn_id)

self.log.info('Starting job in EMR Containers')
response = emr_containers.start_job(**self.start_job_params)
self.log.info(
'Job %s has been started in EMR Containers in cluster %s',
response["id"],
response["virtualClusterId"],
)
return response["id"]
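
For context, a hedged sketch of wiring the start-job operator above into a DAG; the DAG id, cluster ID, role ARN, release label, and script path are placeholders, not values from this PR.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_containers_start_job import EmrContainersStartJobOperator

with DAG(
    dag_id="emr_eks_example",  # placeholder DAG id
    start_date=datetime(2021, 8, 1),
    schedule_interval=None,  # trigger manually
    catchup=False,
) as dag:
    start_job = EmrContainersStartJobOperator(
        task_id="start_spark_job",
        cluster_id="abcd1234virtualcluster",  # placeholder
        execution_role_arn="arn:aws:iam::123456789012:role/emr-eks-job-role",  # placeholder
        emr_release_label="emr-6.2.0-latest",  # placeholder
        job_driver={
            "sparkSubmitJobDriver": {
                "entryPoint": "s3://my-bucket/scripts/job.py",  # placeholder
            }
        },
    )
    # The job run ID returned by execute() is pushed to XCom for downstream tasks.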
13 changes: 13 additions & 0 deletions airflow/providers/amazon/provider.yaml
@@ -145,6 +145,10 @@ integrations:
how-to-guide:
- /docs/apache-airflow-providers-amazon/operators/dms.rst
tags: [aws]
- integration-name: Amazon EMR on Amazon EKS
external-doc-url: https://aws.amazon.com/emr/features/eks/
logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png
tags: [aws]

operators:
- integration-name: Amazon Athena
@@ -214,6 +218,11 @@ operators:
python-modules:
- airflow.providers.amazon.aws.operators.step_function_get_execution_output
- airflow.providers.amazon.aws.operators.step_function_start_execution
- integration-name: Amazon EMR on Amazon EKS
python-modules:
- airflow.providers.amazon.aws.operators.emr_containers_cancel_job
- airflow.providers.amazon.aws.operators.emr_containers_get_job_state
- airflow.providers.amazon.aws.operators.emr_containers_start_job

sensors:
- integration-name: Amazon Athena
@@ -337,6 +346,9 @@ hooks:
- integration-name: AWS Step Functions
python-modules:
- airflow.providers.amazon.aws.hooks.step_function
- integration-name: Amazon EMR on Amazon EKS
python-modules:
- airflow.providers.amazon.aws.hooks.emr_containers

transfers:
- source-integration-name: Amazon DynamoDB
@@ -399,3 +411,4 @@ hook-class-names:
- airflow.providers.amazon.aws.hooks.s3.S3Hook
- airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook
- airflow.providers.amazon.aws.hooks.emr.EmrHook
- airflow.providers.amazon.aws.hooks.emr_containers.EmrContainersHook