diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index 0b443e0b0a3c9..a6b27846b44ab 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -18,10 +18,12 @@
 __all__ = ['SparkConf']
 
 import sys
+from typing import Dict, List, Optional, Tuple, cast, overload
+from py4j.java_gateway import JVMView, JavaObject  # type: ignore[import]
 
 
-class SparkConf(object):
+class SparkConf(object):
     """
     Configuration for a Spark application. Used to set various Spark
     parameters as key-value pairs.
 
@@ -105,7 +107,11 @@ class SparkConf(object):
     spark.home=/path
     """
 
-    def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
+    _jconf: Optional[JavaObject]
+    _conf: Optional[Dict[str, str]]
+
+    def __init__(self, loadDefaults: bool = True, _jvm: Optional[JVMView] = None,
+                 _jconf: Optional[JavaObject] = None):
         """
         Create a new Spark configuration.
         """
@@ -113,7 +119,7 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
             self._jconf = _jconf
         else:
             from pyspark.context import SparkContext
-            _jvm = _jvm or SparkContext._jvm
+            _jvm = _jvm or SparkContext._jvm  # type: ignore[attr-defined]
             if _jvm is not None:
                 # JVM is created, so create self._jconf directly through JVM
                 self._jconf = _jvm.SparkConf(loadDefaults)
@@ -124,48 +130,58 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
                 self._jconf = None
                 self._conf = {}
 
-    def set(self, key, value):
+    def set(self, key: str, value: str) -> "SparkConf":
         """Set a configuration property."""
         # Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet.
         if self._jconf is not None:
             self._jconf.set(key, str(value))
         else:
+            assert self._conf is not None
             self._conf[key] = str(value)
         return self
 
-    def setIfMissing(self, key, value):
+    def setIfMissing(self, key: str, value: str) -> "SparkConf":
         """Set a configuration property, if not already set."""
         if self.get(key) is None:
             self.set(key, value)
         return self
 
-    def setMaster(self, value):
+    def setMaster(self, value: str) -> "SparkConf":
         """Set master URL to connect to."""
         self.set("spark.master", value)
         return self
 
-    def setAppName(self, value):
+    def setAppName(self, value: str) -> "SparkConf":
         """Set application name."""
         self.set("spark.app.name", value)
         return self
 
-    def setSparkHome(self, value):
+    def setSparkHome(self, value: str) -> "SparkConf":
         """Set path where Spark is installed on worker nodes."""
         self.set("spark.home", value)
         return self
 
-    def setExecutorEnv(self, key=None, value=None, pairs=None):
+    @overload
+    def setExecutorEnv(self, key: str, value: str) -> "SparkConf":
+        ...
+
+    @overload
+    def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> "SparkConf":
+        ...
+
+    def setExecutorEnv(self, key: Optional[str] = None, value: Optional[str] = None,
+                       pairs: Optional[List[Tuple[str, str]]] = None) -> "SparkConf":
         """Set an environment variable to be passed to executors."""
         if (key is not None and pairs is not None) or (key is None and pairs is None):
             raise RuntimeError("Either pass one key-value pair or a list of pairs")
         elif key is not None:
-            self.set("spark.executorEnv." + key, value)
+            self.set("spark.executorEnv.{}".format(key), cast(str, value))
         elif pairs is not None:
             for (k, v) in pairs:
-                self.set("spark.executorEnv." + k, v)
+                self.set("spark.executorEnv.{}".format(k), v)
         return self
 
-    def setAll(self, pairs):
+    def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf":
         """
         Set multiple parameters, passed as a list of key-value pairs.
@@ -178,38 +194,40 @@ def setAll(self, pairs):
             self.set(k, v)
         return self
 
-    def get(self, key, defaultValue=None):
+    def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]:
         """Get the configured value for some key, or return a default otherwise."""
-        if defaultValue is None: # Py4J doesn't call the right get() if we pass None
+        if defaultValue is None:  # Py4J doesn't call the right get() if we pass None
             if self._jconf is not None:
                 if not self._jconf.contains(key):
                     return None
                 return self._jconf.get(key)
             else:
-                if key not in self._conf:
-                    return None
-                return self._conf[key]
+                assert self._conf is not None
+                return self._conf.get(key, None)
         else:
             if self._jconf is not None:
                 return self._jconf.get(key, defaultValue)
             else:
+                assert self._conf is not None
                 return self._conf.get(key, defaultValue)
 
-    def getAll(self):
+    def getAll(self) -> List[Tuple[str, str]]:
         """Get all values as a list of key-value pairs."""
         if self._jconf is not None:
-            return [(elem._1(), elem._2()) for elem in self._jconf.getAll()]
+            return [(elem._1(), elem._2()) for elem in cast(JavaObject, self._jconf).getAll()]
         else:
-            return self._conf.items()
+            assert self._conf is not None
+            return list(self._conf.items())
 
-    def contains(self, key):
+    def contains(self, key: str) -> bool:
         """Does this configuration contain a given key?"""
         if self._jconf is not None:
             return self._jconf.contains(key)
         else:
+            assert self._conf is not None
             return key in self._conf
 
-    def toDebugString(self):
+    def toDebugString(self) -> str:
         """
         Returns a printable version of the configuration, as a list of
         key=value pairs, one per line.
@@ -217,10 +235,11 @@ def toDebugString(self):
         if self._jconf is not None:
             return self._jconf.toDebugString()
         else:
+            assert self._conf is not None
             return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items())
 
 
-def _test():
+def _test() -> None:
     import doctest
     (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
     if failure_count:
diff --git a/python/pyspark/conf.pyi b/python/pyspark/conf.pyi
deleted file mode 100644
index f7ca61dea9cd2..0000000000000
--- a/python/pyspark/conf.pyi
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import List, Optional, Tuple
-
-from py4j.java_gateway import JVMView, JavaObject  # type: ignore[import]
-
-class SparkConf:
-    def __init__(
-        self,
-        loadDefaults: bool = ...,
-        _jvm: Optional[JVMView] = ...,
-        _jconf: Optional[JavaObject] = ...,
-    ) -> None: ...
-    def set(self, key: str, value: str) -> SparkConf: ...
-    def setIfMissing(self, key: str, value: str) -> SparkConf: ...
-    def setMaster(self, value: str) -> SparkConf: ...
-    def setAppName(self, value: str) -> SparkConf: ...
-    def setSparkHome(self, value: str) -> SparkConf: ...
-    @overload
-    def setExecutorEnv(self, key: str, value: str) -> SparkConf: ...
-    @overload
-    def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> SparkConf: ...
-    def setAll(self, pairs: List[Tuple[str, str]]) -> SparkConf: ...
-    def get(self, key: str, defaultValue: Optional[str] = ...) -> str: ...
-    def getAll(self) -> List[Tuple[str, str]]: ...
-    def contains(self, key: str) -> bool: ...
-    def toDebugString(self) -> str: ...
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 728d658e45393..212ea9e410f78 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -610,7 +610,7 @@ def _create_shell_session() -> "SparkSession":
     try:
         # Try to access HiveConf, it will raise exception if Hive is not added
         conf = SparkConf()
-        if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
+        if cast(str, conf.get('spark.sql.catalogImplementation', 'hive')).lower() == 'hive':
             (SparkContext._jvm  # type: ignore[attr-defined]
                 .org.apache.hadoop.hive.conf.HiveConf())
             return SparkSession.builder\
@@ -619,7 +619,7 @@ def _create_shell_session() -> "SparkSession":
         else:
             return SparkSession.builder.getOrCreate()
     except (py4j.protocol.Py4JError, TypeError):
-        if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
+        if cast(str, conf.get('spark.sql.catalogImplementation', '')).lower() == 'hive':
             warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
                           "please make sure you build spark with hive")
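A short usage sketch (not part of the patch) illustrating the call-site effect of the inlined hints: the `setExecutorEnv` overloads accept either one key/value pair or a keyword-only list of pairs, and `get` is now typed `Optional[str]`, which is why the `session.py` hunks above add `cast(str, ...)`. It assumes only that a pyspark build containing this change is importable; no running JVM is needed, since `SparkConf` falls back to its plain dict storage when no SparkContext exists.

```python
from typing import cast

from pyspark.conf import SparkConf

conf = SparkConf(loadDefaults=False)

# First overload: a single key/value pair.
conf.setExecutorEnv("PYTHONHASHSEED", "0")

# Second overload: a keyword-only list of pairs.
conf.setExecutorEnv(pairs=[("OMP_NUM_THREADS", "1")])

# get() returns Optional[str]; cast() tells the type checker that a non-None
# default guarantees a str, mirroring the session.py change above.
impl = cast(str, conf.get("spark.sql.catalogImplementation", "hive")).lower()
print(impl)                  # 'hive'
print(conf.toDebugString())  # spark.executorEnv.PYTHONHASHSEED=0 ...
```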