
Commit 9103ea1

Authored Nov 16, 2023
Add support for Spark Connect to pyspark decorator (#35665)
* Add support for Spark Connect to pyspark decorator. Apache Spark 3.4 introduced Spark Connect, which allows remote connectivity to a Spark cluster using the DataFrame API.
1 parent e29464b commit 9103ea1

File tree

8 files changed: +339 −15 lines

 

airflow/providers/apache/spark/decorators/pyspark.py

+19-8
@@ -23,6 +23,7 @@
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
 from airflow.hooks.base import BaseHook
 from airflow.operators.python import PythonOperator
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -73,34 +74,44 @@ def execute(self, context: Context):
         from pyspark import SparkConf
         from pyspark.sql import SparkSession

-        conf = SparkConf().setAppName(f"{self.dag_id}-{self.task_id}")
+        conf = SparkConf()
+        conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")

-        master = "local[*]"
+        url = "local[*]"
         if self.conn_id:
+            # we handle both spark connect and spark standalone
            conn = BaseHook.get_connection(self.conn_id)
-            if conn.port:
-                master = f"{conn.host}:{conn.port}"
+            if conn.conn_type == SparkConnectHook.conn_type:
+                url = SparkConnectHook(self.conn_id).get_connection_url()
+            elif conn.port:
+                url = f"{conn.host}:{conn.port}"
             elif conn.host:
-                master = conn.host
+                url = conn.host

             for key, value in conn.extra_dejson.items():
                 conf.set(key, value)

-        conf.setMaster(master)
+        # you cannot have both remote and master
+        if url.startswith("sc://"):
+            conf.set("spark.remote", url)

         # task can override connection config
         for key, value in self.config_kwargs.items():
             conf.set(key, value)

+        if not conf.get("spark.remote") and not conf.get("spark.master"):
+            conf.set("spark.master", url)
+
         spark = SparkSession.builder.config(conf=conf).getOrCreate()
-        sc = spark.sparkContext

         if not self.op_kwargs:
             self.op_kwargs = {}

         op_kwargs: dict[str, Any] = dict(self.op_kwargs)
         op_kwargs["spark"] = spark
-        op_kwargs["sc"] = sc
+
+        # spark context is not available when using spark connect
+        op_kwargs["sc"] = spark.sparkContext if not conf.get("spark.remote") else None

         self.op_kwargs = op_kwargs
         return super().execute(context)
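
For orientation, a minimal standalone sketch of the decision the decorator now makes (the app name and URLs below are placeholders, not values from the commit): a "sc://" URL is handed to Spark as spark.remote, anything else falls through to spark.master, and the two are never set together.

    from pyspark import SparkConf


    def configure(url: str) -> SparkConf:
        # Mirrors the logic above: spark.remote for Spark Connect, spark.master otherwise.
        conf = SparkConf()
        conf.set("spark.app.name", "example_dag-example_task")
        if url.startswith("sc://"):
            conf.set("spark.remote", url)
        if not conf.get("spark.remote") and not conf.get("spark.master"):
            conf.set("spark.master", url)
        return conf


    assert configure("sc://spark-cluster:15002").get("spark.remote") == "sc://spark-cluster:15002"
    assert configure("local[*]").get("spark.master") == "local[*]"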
airflow/providers/apache/spark/hooks/spark_connect.py

+99

@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+from urllib.parse import quote, urlparse, urlunparse
+
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class SparkConnectHook(BaseHook, LoggingMixin):
+    """Hook for Spark Connect."""
+
+    # from pyspark's ChannelBuilder
+    PARAM_USE_SSL = "use_ssl"
+    PARAM_TOKEN = "token"
+    PARAM_USER_ID = "user_id"
+
+    conn_name_attr = "conn_id"
+    default_conn_name = "spark_connect_default"
+    conn_type = "spark_connect"
+    hook_name = "Spark Connect"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour."""
+        return {
+            "hidden_fields": [
+                "schema",
+            ],
+            "relabeling": {"password": "Token", "login": "User ID"},
+        }
+
+    @staticmethod
+    def get_connection_form_widgets() -> dict[str, Any]:
+        """Return connection widgets to add to the connection form."""
+        from flask_babel import lazy_gettext
+        from wtforms import BooleanField
+
+        return {
+            SparkConnectHook.PARAM_USE_SSL: BooleanField(lazy_gettext("Use SSL"), default=False),
+        }
+
+    def __init__(self, conn_id: str = default_conn_name) -> None:
+        super().__init__()
+        self._conn_id = conn_id
+
+    def get_connection_url(self) -> str:
+        conn = self.get_connection(self._conn_id)
+
+        host = conn.host
+        if conn.host.find("://") == -1:
+            host = f"sc://{conn.host}"
+        if conn.port:
+            host = f"{host}:{conn.port}"
+
+        url = urlparse(host)
+
+        if url.path:
+            raise ValueError(f"Path {url.path} is not supported in Spark Connect connection URL")
+
+        params = []
+
+        if conn.login:
+            params.append(f"{SparkConnectHook.PARAM_USER_ID}={quote(conn.login)}")
+
+        if conn.password:
+            params.append(f"{SparkConnectHook.PARAM_TOKEN}={quote(conn.password)}")
+
+        use_ssl = conn.extra_dejson.get(SparkConnectHook.PARAM_USE_SSL)
+        if use_ssl is not None:
+            params.append(f"{SparkConnectHook.PARAM_USE_SSL}={quote(str(use_ssl))}")
+
+        return urlunparse(
+            (
+                "sc",
+                url.netloc,
+                "/",
+                ";".join(params),  # params
+                "",
+                url.fragment,
+            )
+        )
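
As a quick usage sketch, the hook turns an Airflow connection into a Spark Connect URL. The connection id and credentials below are hypothetical, and exposing the connection as a JSON-valued AIRFLOW_CONN_* environment variable assumes a recent Airflow 2.x (2.3+).

    import json
    import os

    from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook

    # Hypothetical connection exposed via an environment variable,
    # so no metadata database is needed to resolve it.
    os.environ["AIRFLOW_CONN_MY_SPARK_CONNECT"] = json.dumps(
        {
            "conn_type": "spark_connect",
            "host": "sc://spark-cluster",
            "port": 15002,
            "login": "airflow",
            "password": "secret",
            "extra": {"use_ssl": True},
        }
    )

    hook = SparkConnectHook(conn_id="my_spark_connect")
    # Expected to resolve to something like:
    #   sc://spark-cluster:15002/;user_id=airflow;token=secret;use_ssl=True
    print(hook.get_connection_url())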

airflow/providers/apache/spark/provider.yaml

+4
@@ -51,6 +51,7 @@ versions:
 dependencies:
   - apache-airflow>=2.5.0
   - pyspark
+  - grpcio-status

 integrations:
   - integration-name: Apache Spark
@@ -70,13 +71,16 @@ operators:
 hooks:
   - integration-name: Apache Spark
     python-modules:
+      - airflow.providers.apache.spark.hooks.spark_connect
       - airflow.providers.apache.spark.hooks.spark_jdbc
       - airflow.providers.apache.spark.hooks.spark_jdbc_script
       - airflow.providers.apache.spark.hooks.spark_sql
       - airflow.providers.apache.spark.hooks.spark_submit


 connection-types:
+  - hook-class-name: airflow.providers.apache.spark.hooks.spark_connect.SparkConnectHook
+    connection-type: spark_connect
   - hook-class-name: airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook
     connection-type: spark_jdbc
   - hook-class-name: airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook

docs/apache-airflow-providers-apache-spark/connections/spark.rst

+10-1
@@ -27,7 +27,7 @@ The Apache Spark connection type enables connection to Apache Spark.
 Default Connection IDs
 ----------------------

-Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by default. Spark SQL hooks and operators point to ``spark_sql_default`` by default.
+Spark Submit and Spark JDBC hooks and operators use ``spark_default`` by default. Spark SQL hooks and operators point to ``spark_sql_default`` by default. The Spark Connect hook uses ``spark_connect_default`` by default.

 Configuring the Connection
 --------------------------
@@ -45,6 +45,15 @@ Extra (optional)
     * ``spark-binary`` - The command to use for Spark submit. Some distros may use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit``, ``spark2-submit`` or ``spark3-submit`` are allowed as value.
     * ``namespace`` - Kubernetes namespace (``spark.kubernetes.namespace``) to divide cluster resources between multiple users (via resource quota).

+User ID (optional, only applies to Spark Connect)
+    The user ID to authenticate with the proxy.
+
+Token (optional, only applies to Spark Connect)
+    The token to authenticate with the proxy.
+
+Use SSL (optional, only applies to Spark Connect)
+    Whether to use SSL when connecting.
+
 When specifying the connection in environment variable you should specify
 it using URI syntax.
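
For illustration, one way to see the URI form mentioned above is to build the connection programmatically and print it. This is only a sketch: the conn_id and credential values are made up, and the exact URI rendering can vary between Airflow versions.

    from airflow.models.connection import Connection

    # Hypothetical Spark Connect connection showing how the fields above map onto the
    # connection object (User ID -> login, Token -> password, Use SSL -> extra).
    conn = Connection(
        conn_id="spark_connect_default",
        conn_type="spark_connect",
        host="sc://spark-cluster",
        port=15002,
        login="airflow",
        password="secret",
        extra={"use_ssl": True},
    )
    # URI form usable in an AIRFLOW_CONN_<CONN_ID> environment variable.
    print(conn.get_uri())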

docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst

+22-1
@@ -23,7 +23,7 @@ PySpark Decorator
 =================

 Python callable wrapped within the ``@task.pyspark`` decorator
-is injected with a SparkContext object.
+is injected with a SparkSession and SparkContext object if available.

 Parameters
 ----------
@@ -49,3 +49,24 @@ that the ``spark`` and ``sc`` objects are injected into the function.
     :dedent: 4
     :start-after: [START task_pyspark]
     :end-before: [END task_pyspark]
+
+
+Spark Connect
+-------------
+
+In `Apache Spark 3.4 <https://spark.apache.org/docs/latest/spark-connect-overview.html>`_,
+Spark Connect introduced a decoupled client-server architecture
+that allows remote connectivity to Spark clusters using the DataFrame API. Using
+Spark Connect is the preferred way in Airflow to make use of the PySpark decorator,
+because it does not require running the Spark driver on the same host as Airflow.
+To make use of Spark Connect, you prepend your host URL with ``sc://``. For example,
+``sc://spark-cluster:15002``.
+
+
+Authentication
+^^^^^^^^^^^^^^
+
+Spark Connect does not have built-in authentication. The gRPC HTTP/2 interface, however,
+allows the use of authentication to communicate with the Spark Connect server through
+authenticating proxies. To make use of authentication, make sure to create a ``Spark Connect``
+connection and set the right credentials.
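
To tie the pieces together, a sketch of a DAG using the decorator against a Spark Connect endpoint. The schedule, start date, and host are illustrative; it assumes a Spark Connect connection (host like sc://spark-cluster:15002) saved under the hook's default conn_id, spark_connect_default.

    import datetime

    from airflow.decorators import dag, task


    @dag(start_date=datetime.datetime(2023, 11, 1), schedule=None, catchup=False)
    def spark_connect_example():
        @task.pyspark(conn_id="spark_connect_default")
        def count_rows(spark, sc):
            # With Spark Connect there is no SparkContext, so sc is injected as None.
            df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
            return df.count()

        count_rows()


    spark_connect_example()

Because the connection's host starts with sc://, the decorator routes it through spark.remote as shown in the pyspark.py changes above, and the task function receives only the remote SparkSession.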

generated/provider_dependencies.json

+1
@@ -190,6 +190,7 @@
   "apache.spark": {
     "deps": [
       "apache-airflow>=2.5.0",
+      "grpcio-status",
       "pyspark"
     ],
     "cross-providers-deps": [

tests/providers/apache/spark/decorators/test_pyspark.py

+112-5
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations

+from typing import Any
 from unittest import mock

 import pytest
@@ -27,6 +28,22 @@
 DEFAULT_DATE = timezone.datetime(2021, 9, 1)


+class FakeConfig:
+    data: dict[str, Any]
+
+    def __init__(self, data: dict[str, Any] | None = None):
+        if data:
+            self.data = data
+        else:
+            self.data = {}
+
+    def get(self, key: str, default: Any = None) -> Any:
+        return self.data.get(key, default)
+
+    def set(self, key: str, value: Any) -> None:
+        self.data[key] = value
+
+
 class TestPysparkDecorator:
     def setup_method(self):
         db.merge_conn(
@@ -38,14 +55,47 @@ def setup_method(self):
             )
         )

+        db.merge_conn(
+            Connection(
+                conn_id="spark-connect",
+                conn_type="spark",
+                host="sc://localhost",
+                extra="",
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id="spark-connect-auth",
+                conn_type="spark_connect",
+                host="sc://localhost",
+                password="1234",
+                login="connect",
+                extra={
+                    "use_ssl": True,
+                },
+            )
+        )
+
     @pytest.mark.db_test
-    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.SparkConf")
     @mock.patch("pyspark.sql.SparkSession")
     def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, dag_maker):
+        config = FakeConfig()
+
+        builder = mock.MagicMock()
+        spark_mock.builder.config.return_value = builder
+        builder.getOrCreate.return_value = builder
+        builder.sparkContext.return_value = builder
+
+        conf_mock.return_value = config
+
         @task.pyspark(conn_id="pyspark_local", config_kwargs={"spark.executor.memory": "2g"})
         def f(spark, sc):
             import random

+            assert spark is not None
+            assert sc is not None
             return [random.random() for _ in range(100)]

         with dag_maker():
@@ -55,14 +105,20 @@ def f(spark, sc):
         ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
         ti = dr.get_task_instances()[0]
         assert len(ti.xcom_pull()) == 100
-        conf_mock().set.assert_called_with("spark.executor.memory", "2g")
-        conf_mock().setMaster.assert_called_once_with("spark://none")
+        assert config.get("spark.master") == "spark://none"
+        assert config.get("spark.executor.memory") == "2g"
+        assert config.get("spark.remote") is None
+        assert config.get("spark.app.name")
+
         spark_mock.builder.config.assert_called_once_with(conf=conf_mock())

     @pytest.mark.db_test
-    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.SparkConf")
     @mock.patch("pyspark.sql.SparkSession")
     def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
+        config = FakeConfig()
+        conf_mock.return_value = config
+
         e = 2

         @task.pyspark
@@ -76,5 +132,56 @@ def f():
         ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
         ti = dr.get_task_instances()[0]
         assert ti.xcom_pull() == e
-        conf_mock().setMaster.assert_called_once_with("local[*]")
+        assert config.get("spark.master") == "local[*]"
         spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_spark_connect(self, spark_mock, conf_mock, dag_maker):
+        config = FakeConfig()
+        conf_mock.return_value = config
+
+        @task.pyspark(conn_id="spark-connect")
+        def f(spark, sc):
+            assert spark is not None
+            assert sc is None
+
+            return True
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert ti.xcom_pull()
+        assert config.get("spark.remote") == "sc://localhost"
+        assert config.get("spark.master") is None
+        assert config.get("spark.app.name")
+        spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_spark_connect_auth(self, spark_mock, conf_mock, dag_maker):
+        config = FakeConfig()
+        conf_mock.return_value = config
+
+        @task.pyspark(conn_id="spark-connect-auth")
+        def f(spark, sc):
+            assert spark is not None
+            assert sc is None
+
+            return True
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert ti.xcom_pull()
+        assert config.get("spark.remote") == "sc://localhost/;user_id=connect;token=1234;use_ssl=True"
+        assert config.get("spark.master") is None
+        assert config.get("spark.app.name")
tests/providers/apache/spark/hooks/test_spark_connect.py

+72

@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
+from airflow.utils import db
+
+pytestmark = pytest.mark.db_test
+
+
+class TestSparkConnectHook:
+    def setup_method(self):
+        db.merge_conn(
+            Connection(
+                conn_id="spark-default",
+                conn_type="spark_connect",
+                host="sc://spark-host",
+                port=1000,
+                login="spark-user",
+                password="1234",
+                extra='{"queue": "root.etl", "deploy-mode": "cluster"}',
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id="spark-test",
+                conn_type="spark_connect",
+                host="nowhere",
+                login="spark-user",
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id="spark-app",
+                conn_type="spark_connect",
+                host="sc://cluster/app",
+                login="spark-user",
+            )
+        )
+
+    def test_get_connection_url(self):
+        expected_url = "sc://spark-host:1000/;user_id=spark-user;token=1234"
+        hook = SparkConnectHook(conn_id="spark-default")
+        assert hook.get_connection_url() == expected_url
+
+        expected_url = "sc://nowhere/;user_id=spark-user"
+        hook = SparkConnectHook(conn_id="spark-test")
+        assert hook.get_connection_url() == expected_url
+
+        hook = SparkConnectHook(conn_id="spark-app")
+        with pytest.raises(ValueError):
+            hook.get_connection_url()
