@@ -147,8 +147,11 @@ class SQLTableDataSet(AbstractDataSet):
147
147
148
148
"""
149
149
150
- DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
151
- DEFAULT_SAVE_ARGS = {"index" : False } # type: Dict[str, Any]
150
+ DEFAULT_LOAD_ARGS : Dict [str , Any ] = {}
151
+ DEFAULT_SAVE_ARGS : Dict [str , Any ] = {"index" : False }
152
+ # using Any because of Sphinx but it should be
153
+ # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
154
+ engines : Dict [str , Any ] = {}
152
155
153
156
def __init__ (
154
157
self ,
@@ -207,42 +210,50 @@ def __init__(
207
210
self ._load_args ["table_name" ] = table_name
208
211
self ._save_args ["name" ] = table_name
209
212
210
- self ._load_args ["con" ] = self ._save_args ["con" ] = credentials ["con" ]
213
+ self ._connection_str = credentials ["con" ]
214
+ self .create_connection (self ._connection_str )
215
+
216
+ @classmethod
217
+ def create_connection (cls , connection_str : str ) -> None :
218
+ """Given a connection string, create singleton connection
219
+ to be used across all instances of `SQLTableDataSet` that
220
+ need to connect to the same source.
221
+ """
222
+ if connection_str in cls .engines :
223
+ return
224
+
225
+ try :
226
+ engine = create_engine (connection_str )
227
+ except ImportError as import_error :
228
+ raise _get_missing_module_error (import_error ) from import_error
229
+ except NoSuchModuleError as exc :
230
+ raise _get_sql_alchemy_missing_error () from exc
231
+
232
+ cls .engines [connection_str ] = engine
211
233
212
234
def _describe (self ) -> Dict [str , Any ]:
213
- load_args = self ._load_args . copy ( )
214
- save_args = self ._save_args . copy ( )
235
+ load_args = copy . deepcopy ( self ._load_args )
236
+ save_args = copy . deepcopy ( self ._save_args )
215
237
del load_args ["table_name" ]
216
- del load_args ["con" ]
217
238
del save_args ["name" ]
218
- del save_args ["con" ]
219
239
return dict (
220
240
table_name = self ._load_args ["table_name" ],
221
241
load_args = load_args ,
222
242
save_args = save_args ,
223
243
)
224
244
225
245
def _load (self ) -> pd .DataFrame :
226
- try :
227
- return pd .read_sql_table (** self ._load_args )
228
- except ImportError as import_error :
229
- raise _get_missing_module_error (import_error ) from import_error
230
- except NoSuchModuleError as exc :
231
- raise _get_sql_alchemy_missing_error () from exc
246
+ engine = self .engines [self ._connection_str ] # type:ignore
247
+ return pd .read_sql_table (con = engine , ** self ._load_args )
232
248
233
249
def _save (self , data : pd .DataFrame ) -> None :
234
- try :
235
- data .to_sql (** self ._save_args )
236
- except ImportError as import_error :
237
- raise _get_missing_module_error (import_error ) from import_error
238
- except NoSuchModuleError as exc :
239
- raise _get_sql_alchemy_missing_error () from exc
250
+ engine = self .engines [self ._connection_str ] # type: ignore
251
+ data .to_sql (con = engine , ** self ._save_args )
240
252
241
253
def _exists (self ) -> bool :
242
- eng = create_engine ( self ._load_args [ "con" ])
254
+ eng = self .engines [ self . _connection_str ] # type: ignore
243
255
schema = self ._load_args .get ("schema" , None )
244
256
exists = self ._load_args ["table_name" ] in eng .table_names (schema )
245
- eng .dispose ()
246
257
return exists
247
258
248
259
@@ -299,6 +310,10 @@ class SQLQueryDataSet(AbstractDataSet):
299
310
300
311
"""
301
312
313
+ # using Any because of Sphinx but it should be
314
+ # sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
315
+ engines : Dict [str , Any ] = {}
316
+
302
317
def __init__ ( # pylint: disable=too-many-arguments
303
318
self ,
304
319
sql : str = None ,
@@ -374,32 +389,45 @@ def __init__( # pylint: disable=too-many-arguments
374
389
self ._protocol = protocol
375
390
self ._fs = fsspec .filesystem (self ._protocol , ** _fs_credentials , ** _fs_args )
376
391
self ._filepath = path
377
- self ._load_args ["con" ] = credentials ["con" ]
392
+ self ._connection_str = credentials ["con" ]
393
+ self .create_connection (self ._connection_str )
394
+
395
+ @classmethod
396
+ def create_connection (cls , connection_str : str ) -> None :
397
+ """Given a connection string, create singleton connection
398
+ to be used across all instances of `SQLQueryDataSet` that
399
+ need to connect to the same source.
400
+ """
401
+ if connection_str in cls .engines :
402
+ return
403
+
404
+ try :
405
+ engine = create_engine (connection_str )
406
+ except ImportError as import_error :
407
+ raise _get_missing_module_error (import_error ) from import_error
408
+ except NoSuchModuleError as exc :
409
+ raise _get_sql_alchemy_missing_error () from exc
410
+
411
+ cls .engines [connection_str ] = engine
378
412
379
413
def _describe (self ) -> Dict [str , Any ]:
380
414
load_args = copy .deepcopy (self ._load_args )
381
- desc = {}
382
- desc ["sql" ] = str (load_args .pop ("sql" , None ))
383
- desc ["filepath" ] = str (self ._filepath )
384
- del load_args ["con" ]
385
- desc ["load_args" ] = str (load_args )
386
-
387
- return desc
415
+ return dict (
416
+ sql = str (load_args .pop ("sql" , None )),
417
+ filepath = str (self ._filepath ),
418
+ load_args = str (load_args ),
419
+ )
388
420
389
421
def _load (self ) -> pd .DataFrame :
390
422
load_args = copy .deepcopy (self ._load_args )
423
+ engine = self .engines [self ._connection_str ] # type: ignore
391
424
392
425
if self ._filepath :
393
426
load_path = get_filepath_str (PurePosixPath (self ._filepath ), self ._protocol )
394
427
with self ._fs .open (load_path , mode = "r" ) as fs_file :
395
428
load_args ["sql" ] = fs_file .read ()
396
429
397
- try :
398
- return pd .read_sql_query (** load_args )
399
- except ImportError as import_error :
400
- raise _get_missing_module_error (import_error ) from import_error
401
- except NoSuchModuleError as exc :
402
- raise _get_sql_alchemy_missing_error () from exc
430
+ return pd .read_sql_query (con = engine , ** load_args )
403
431
404
432
def _save (self , data : pd .DataFrame ) -> None :
405
433
raise DataSetError ("`save` is not supported on SQLQueryDataSet" )
0 commit comments