Commit ed7fa70

dungdm93 authored and shcoderAlex committed
feat: Trino Authentications (apache#17593)
* feat: support Trino Authentications

  Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>

* docs: Trino Authentications

  Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
1 parent 3f0ec34 commit ed7fa70

9 files changed, +286 -23 lines changed
docs/src/pages/docs/Connecting to Databases/trino.mdx

+83 -8
@@ -8,20 +8,95 @@ version: 1
 
 ## Trino
 
-Supported trino version 352 and higher
-
-The [sqlalchemy-trino](https://pypi.org/project/sqlalchemy-trino/) library is the recommended way to connect to Trino through SQLAlchemy.
-
-The expected connection string is formatted as follows:
+Superset supports Trino >=352 via SQLAlchemy by using the [sqlalchemy-trino](https://pypi.org/project/sqlalchemy-trino/) library.
 
+### Connection String
+The connection string format is as follows:
 ```
 trino://{username}:{password}@{hostname}:{port}/{catalog}
 ```
-If you are running trino with docker on local machine please use the following connection URL
 
+If you are running Trino with docker on local machine, please use the following connection URL
 ```
 trino://trino@host.docker.internal:8080
 ```
 
-Reference:
-[Trino-Superset-Podcast](https://trino.io/episodes/12.html)
+### Authentications
+#### 1. Basic Authentication
+You can provide `username`/`password` in the connection string or in the `Secure Extra` field at `Advanced / Security`
+* In Connection String
+```
+trino://{username}:{password}@{hostname}:{port}/{catalog}
+```
+
+* In `Secure Extra` field
+```json
+{
+    "auth_method": "basic",
+    "auth_params": {
+        "username": "<username>",
+        "password": "<password>"
+    }
+}
+```
+
+NOTE: if both are provided, `Secure Extra` always takes higher priority.
+
+#### 2. Kerberos Authentication
+In `Secure Extra` field, config as following example:
+```json
+{
+    "auth_method": "kerberos",
+    "auth_params": {
+        "service_name": "superset",
+        "config": "/path/to/krb5.config",
+        ...
+    }
+}
+```
+
+All fields in `auth_params` are passed directly to the [`KerberosAuthentication`](https://github.com/trinodb/trino-python-client/blob/0.306.0/trino/auth.py#L40) class.
+
+#### 3. JWT Authentication
+Config `auth_method` and provide token in `Secure Extra` field
+```json
+{
+    "auth_method": "jwt",
+    "auth_params": {
+        "token": "<your-jwt-token>"
+    }
+}
+```
+
+#### 4. Custom Authentication
+To use custom authentication, first you need to add it into
+`ALLOWED_EXTRA_AUTHENTICATIONS` allow list in Superset config file:
+```python
+from your.module import AuthClass
+from another.extra import auth_method
+
+ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {
+    "trino": {
+        "custom_auth": AuthClass,
+        "another_auth_method": auth_method,
+    },
+}
+```
+
+Then in `Secure Extra` field:
+```json
+{
+    "auth_method": "custom_auth",
+    "auth_params": {
+        ...
+    }
+}
+```
+
+You can also use custom authentication by providing reference to your `trino.auth.Authentication` class
+or factory function (which returns an `Authentication` instance) to `auth_method`.
+
+All fields in `auth_params` are passed directly to your class/function.
+
+**Reference**:
+* [Trino-Superset-Podcast](https://trino.io/episodes/12.html)
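
Review note on the connection string format documented above: SQLAlchemy URLs require special characters in the username or password (such as `@` or `/`) to be URL-encoded. The sketch below shows one way to build such a URI; the helper name `build_trino_uri` and the example host, port, and catalog are hypothetical and not part of this commit.

```python
from urllib.parse import quote_plus


def build_trino_uri(
    username: str, password: str, hostname: str, port: int, catalog: str
) -> str:
    """Build a trino:// URI, URL-encoding the credentials (hypothetical helper)."""
    return (
        f"trino://{quote_plus(username)}:{quote_plus(password)}"
        f"@{hostname}:{port}/{catalog}"
    )


# Example values are placeholders, not taken from the commit.
uri = build_trino_uri("alice", "p@ss/word", "trino.example.com", 8443, "hive")
print(uri)
```

Putting the credentials in `Secure Extra` instead, as the new docs describe, avoids the encoding question because the values travel as JSON rather than as part of the URL.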

requirements/base.txt

+2
@@ -283,6 +283,8 @@ werkzeug==1.0.1
     # via
     #   flask
     #   flask-jwt-extended
+wrapt==1.12.1
+    # via -r requirements/base.in
 wtforms==2.3.3
     # via
     #   flask-appbuilder

requirements/testing.in

+1
@@ -38,3 +38,4 @@ statsd
 pytest-mock
 # DB dependencies
 -e file:.[bigquery]
+-e file:.[trino]

requirements/testing.txt

+5 -5
@@ -1,4 +1,4 @@
-# SHA1:9658361c2ab00a6b27c5875b7b3557c2999854ba
+# SHA1:7a8e256097b4758bdeda2529d3d4d31e421e1a3c
 #
 # This file is autogenerated by pip-compile-multi
 # To update, run:
@@ -11,8 +11,6 @@
     # via
     #   -r requirements/base.in
     #   -r requirements/testing.in
-appnope==0.1.2
-    # via ipython
 astroid==2.6.6
     # via pylint
 backcall==0.2.0
@@ -166,20 +164,22 @@ requests-oauthlib==1.3.0
     # via google-auth-oauthlib
 rsa==4.7.2
     # via google-auth
+sqlalchemy-trino==0.4.1
+    # via apache-superset
 statsd==3.3.0
     # via -r requirements/testing.in
 traitlets==5.0.5
     # via
     #   ipython
     #   matplotlib-inline
+trino==0.306
+    # via sqlalchemy-trino
 typing-inspect==0.7.1
     # via libcst
 wcwidth==0.2.5
     # via prompt-toolkit
 websocket-client==1.2.0
     # via docker
-wrapt==1.12.1
-    # via astroid
 
 # The following packages are considered to be unsafe in a requirements file:
 # pip

superset/config.py

+21 -3
@@ -723,6 +723,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
 # Force refresh while auto-refresh in dashboard
 DASHBOARD_AUTO_REFRESH_MODE: Literal["fetch", "force"] = "force"
 
+
 # Default celery config is to use SQLA as a broker, in a production setting
 # you'll want to use a proper broker as specified here:
 # http://docs.celeryproject.org/en/latest/getting-started/brokers/index.html
@@ -872,6 +873,8 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
 # The directory within the bucket specified above that will
 # contain all the external tables
 CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
+
+
 # Function that creates upload directory dynamically based on the
 # database used, user and schema provided.
 def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC(  # pylint: disable=invalid-name
@@ -986,6 +989,19 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
 # See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93  # pylint: disable=line-too-long,useless-suppression
 PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds())
 
+# Allow list of custom authentications for each DB engine.
+# Example:
+# from your.module import AuthClass
+# from another.extra import auth_method
+#
+# ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {
+#     "trino": {
+#         "custom_auth": AuthClass,
+#         "another_auth_method": auth_method,
+#     },
+# }
+ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {}
+
 # Allow for javascript controls components
 # this enables programmers to customize certain charts (like the
 # geospatial ones) by inputing javascript in controls. This exposes
@@ -1012,6 +1028,7 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
 # as such `create_engine(url, **params)`
 DB_CONNECTION_MUTATOR = None
 
+
 # A function that intercepts the SQL to be executed and can alter it.
 # The use case is can be around adding some sort of comment header
 # with information such as the username and worker node information
@@ -1323,8 +1340,8 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument
 if CONFIG_PATH_ENV_VAR in os.environ:
     # Explicitly import config module that is not necessarily in pythonpath; useful
     # for case where app is being executed via pex.
+    cfg_path = os.environ[CONFIG_PATH_ENV_VAR]
     try:
-        cfg_path = os.environ[CONFIG_PATH_ENV_VAR]
         module = sys.modules[__name__]
         override_conf = imp.load_source("superset_config", cfg_path)
         for key in dir(override_conf):
@@ -1339,8 +1356,9 @@ def SQL_QUERY_MUTATOR( # pylint: disable=invalid-name,unused-argument
         raise
 elif importlib.util.find_spec("superset_config") and not is_test():
     try:
-        import superset_config  # pylint: disable=import-error
-        from superset_config import *  # type: ignore # pylint: disable=import-error,wildcard-import,unused-wildcard-import
+        # pylint: disable=import-error,wildcard-import,unused-wildcard-import
+        import superset_config
+        from superset_config import *  # type:ignore
 
         print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]")
     except Exception:

superset/db_engine_specs/base.py

+20
@@ -1273,6 +1273,26 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
                 raise ex
         return extra
 
+    @staticmethod
+    def update_encrypted_extra_params(
+        database: "Database", params: Dict[str, Any]
+    ) -> None:
+        """
+        Some databases require some sensitive information which do not conform to
+        the username:password syntax normally used by SQLAlchemy.
+
+        :param database: database instance from which to extract extras
+        :param params: params to be updated
+        """
+        if not database.encrypted_extra:
+            return
+        try:
+            encrypted_extra = json.loads(database.encrypted_extra)
+            params.update(encrypted_extra)
+        except json.JSONDecodeError as ex:
+            logger.error(ex, exc_info=True)
+            raise ex
+
     @classmethod
     def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
         """Pessimistic readonly, 100% sure statement won't mutate anything"""
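
Review note: the base-class hook above merges the decoded `encrypted_extra` JSON straight into the engine params. A minimal, self-contained sketch of that behaviour follows; `FakeDatabase` is a hypothetical stand-in for `superset.models.core.Database`, and logging is omitted for brevity.

```python
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FakeDatabase:
    """Hypothetical stand-in exposing only the attribute the hook reads."""

    encrypted_extra: Optional[str] = None


def update_encrypted_extra_params(
    database: FakeDatabase, params: Dict[str, Any]
) -> None:
    """Merge the decoded ``encrypted_extra`` JSON into ``params`` in place."""
    if not database.encrypted_extra:
        return
    # json.JSONDecodeError propagates to the caller, as in the diff.
    params.update(json.loads(database.encrypted_extra))


params: Dict[str, Any] = {"connect_args": {"verify": False}}
update_encrypted_extra_params(FakeDatabase('{"pool_size": 5}'), params)
print(params)
```

Because `dict.update` is shallow, a top-level `connect_args` key in the payload would replace an existing `connect_args` entry rather than merge with it.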

superset/db_engine_specs/trino.py

+43
@@ -14,11 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import logging
 from datetime import datetime
 from typing import Any, Dict, List, Optional, TYPE_CHECKING
 from urllib import parse
 
 import simplejson as json
+from flask import current_app
 from sqlalchemy.engine.url import make_url, URL
 
 from superset.db_engine_specs.base import BaseEngineSpec
@@ -27,6 +29,8 @@
 if TYPE_CHECKING:
     from superset.models.core import Database
 
+logger = logging.getLogger(__name__)
+
 
 class TrinoEngineSpec(BaseEngineSpec):
     engine = "trino"
@@ -202,3 +206,42 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
             connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert)
 
         return extra
+
+    @staticmethod
+    def update_encrypted_extra_params(
+        database: "Database", params: Dict[str, Any]
+    ) -> None:
+        if not database.encrypted_extra:
+            return
+        try:
+            encrypted_extra = json.loads(database.encrypted_extra)
+            auth_method = encrypted_extra.pop("auth_method", None)
+            auth_params = encrypted_extra.pop("auth_params", {})
+            if not auth_method:
+                return
+
+            connect_args = params.setdefault("connect_args", {})
+            connect_args["http_scheme"] = "https"
+            # pylint: disable=import-outside-toplevel
+            if auth_method == "basic":
+                from trino.auth import BasicAuthentication as trino_auth  # noqa
+            elif auth_method == "kerberos":
+                from trino.auth import KerberosAuthentication as trino_auth  # noqa
+            elif auth_method == "jwt":
+                from trino.auth import JWTAuthentication as trino_auth  # noqa
+            else:
+                allowed_extra_auths = current_app.config[
+                    "ALLOWED_EXTRA_AUTHENTICATIONS"
+                ].get("trino", {})
+                if auth_method in allowed_extra_auths:
+                    trino_auth = allowed_extra_auths.get(auth_method)
+                else:
+                    raise ValueError(
+                        f"For security reason, custom authentication '{auth_method}' "
+                        f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
+                    )
+
+            connect_args["auth"] = trino_auth(**auth_params)
+        except json.JSONDecodeError as ex:
+            logger.error(ex, exc_info=True)
+            raise ex
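
Review note: the dispatch in `TrinoEngineSpec.update_encrypted_extra_params` can be read as "built-in methods first, then the allow list, otherwise refuse". The sketch below illustrates that flow without importing `trino` or Flask; `FakeBasicAuth`, `FakeCustomAuth`, `BUILTIN_AUTHS`, and `resolve_auth` are hypothetical names introduced only for this illustration.

```python
import json
from typing import Any, Callable, Dict, Optional


class FakeBasicAuth:
    """Hypothetical stand-in for trino.auth.BasicAuthentication."""

    def __init__(self, username: str, password: str) -> None:
        self.username = username
        self.password = password


class FakeCustomAuth:
    """Hypothetical custom authentication class registered by an admin."""

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key


# Built-in methods are always available; anything else must be allow-listed.
BUILTIN_AUTHS: Dict[str, Callable[..., Any]] = {"basic": FakeBasicAuth}


def resolve_auth(
    encrypted_extra: str, allowed_extra: Dict[str, Callable[..., Any]]
) -> Optional[Any]:
    """Return an auth object for the payload, or None if no method is set."""
    extra = json.loads(encrypted_extra)
    auth_method = extra.pop("auth_method", None)
    auth_params = extra.pop("auth_params", {})
    if not auth_method:
        return None
    factory = BUILTIN_AUTHS.get(auth_method) or allowed_extra.get(auth_method)
    if factory is None:
        raise ValueError(
            f"custom authentication '{auth_method}' is not in the allow list"
        )
    # All auth_params are forwarded as keyword arguments to the class/factory.
    return factory(**auth_params)


allow_list = {"custom_auth": FakeCustomAuth}
basic = resolve_auth(
    '{"auth_method": "basic", "auth_params": {"username": "u", "password": "p"}}',
    allow_list,
)
custom = resolve_auth(
    '{"auth_method": "custom_auth", "auth_params": {"api_key": "k"}}', allow_list
)
```

Two behaviours in the diff are worth noting when reviewing: `http_scheme` is forced to `https` whenever an `auth_method` is present, and keys in `Secure Extra` other than `auth_method` and `auth_params` are not used by the Trino override.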

superset/models/core.py

+18 -7
@@ -91,7 +91,6 @@ class KeyValue(Model): # pylint: disable=too-few-public-methods
 
 
 class CssTemplate(Model, AuditMixinNullable):
-
     """CSS templates for dashboards"""
 
     __tablename__ = "css_templates"
@@ -244,7 +243,10 @@ def parameters(self) -> Dict[str, Any]:
         uri = make_url(self.sqlalchemy_uri_decrypted)
         encrypted_extra = self.get_encrypted_extra()
         try:
-            parameters = self.db_engine_spec.get_parameters_from_uri(uri, encrypted_extra=encrypted_extra)  # type: ignore # pylint: disable=line-too-long,useless-suppression
+            # pylint: disable=useless-suppression
+            parameters = self.db_engine_spec.get_parameters_from_uri(  # type: ignore
+                uri, encrypted_extra=encrypted_extra
+            )
         except Exception:  # pylint: disable=broad-except
             parameters = {}
 
@@ -330,7 +332,14 @@ def get_effective_user(
                 effective_username = g.user.username
         return effective_username
 
-    @memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
+    @memoized(
+        watch=(
+            "impersonate_user",
+            "sqlalchemy_uri_decrypted",
+            "extra",
+            "encrypted_extra",
+        )
+    )
     def get_sqla_engine(
         self,
         schema: Optional[str] = None,
@@ -365,7 +374,7 @@ def get_sqla_engine(
         if connect_args:
             params["connect_args"] = connect_args
 
-        params.update(self.get_encrypted_extra())
+        self.update_encrypted_extra_params(params)
 
         if DB_CONNECTION_MUTATOR:
             if not source and request and request.referrer:
@@ -443,9 +452,8 @@ def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
 
         sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
 
-        if (
-            engine.dialect.identifier_preparer._double_percents  # pylint: disable=protected-access
-        ):
+        # pylint: disable=protected-access
+        if engine.dialect.identifier_preparer._double_percents:  # noqa
             sql = sql.replace("%%", "%")
 
         return sql
@@ -639,6 +647,9 @@ def get_encrypted_extra(self) -> Dict[str, Any]:
                 raise ex
         return encrypted_extra
 
+    def update_encrypted_extra_params(self, params: Dict[str, Any]) -> None:
+        self.db_engine_spec.update_encrypted_extra_params(self, params)
+
     def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
         extra = self.get_extra()
         meta = MetaData(**extra.get("metadata_params", {}))
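
Review note: adding `"encrypted_extra"` to the `watch` tuple matters because the engine is now built from that field. The sketch below uses a simplified, hypothetical decorator `memoized_on` (not Superset's actual `memoized` implementation) to show how a cache keyed only on watched attributes can serve a stale result when an unwatched attribute changes.

```python
from functools import wraps
from typing import Any, Callable, Dict, Tuple


def memoized_on(watch: Tuple[str, ...]) -> Callable:
    """Cache a zero-argument method, keyed on the watched attributes (sketch)."""

    def decorator(func: Callable) -> Callable:
        cache: Dict[Tuple[Any, ...], Any] = {}

        @wraps(func)
        def wrapper(self: Any) -> Any:
            key = (id(self),) + tuple(getattr(self, name) for name in watch)
            if key not in cache:
                cache[key] = func(self)
            return cache[key]

        return wrapper

    return decorator


class FakeDatabase:
    """Hypothetical stand-in for the Database model."""

    def __init__(self) -> None:
        self.extra = "{}"
        self.encrypted_extra = '{"auth_method": "basic"}'

    @memoized_on(watch=("extra",))
    def engine_without_watch(self) -> str:
        return f"engine built with {self.encrypted_extra}"

    @memoized_on(watch=("extra", "encrypted_extra"))
    def engine_with_watch(self) -> str:
        return f"engine built with {self.encrypted_extra}"


db = FakeDatabase()
stale_before = db.engine_without_watch()
fresh_before = db.engine_with_watch()

db.encrypted_extra = '{"auth_method": "jwt"}'
stale_after = db.engine_without_watch()  # cache hit: change is not noticed
fresh_after = db.engine_with_watch()  # cache miss: rebuilt with new value
```

In this sketch, only the method that watches `encrypted_extra` picks up the switch from `basic` to `jwt`.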
