@@ -207,60 +207,53 @@ def __init__(
207
207
self ._load_args ["table_name" ] = table_name
208
208
self ._save_args ["name" ] = table_name
209
209
210
- self ._load_args [ "con" ] = self . _save_args [ "con" ] = credentials ["con" ]
211
- self .create_connection (self ._load_args [ "con" ] )
210
+ self ._connection_str = credentials ["con" ]
211
+ self .create_connection (self ._connection_str )
212
212
213
213
@classmethod
214
- def create_connection (cls , con ):
215
- """Create singleton connection to be used
216
- across all instances of `SQLTableDataSet`.
214
+ def create_connection (cls , connection_str : str ) -> None :
215
+ """Given a connection string, create singleton connection
216
+ to be used across all instances of `SQLTableDataSet` that
217
+ need to connect to the same source.
217
218
"""
218
- if hasattr (cls , "engine" ):
219
+ if connection_str in getattr (cls , "engines" , {} ):
219
220
return
220
221
221
- engine = create_engine (con )
222
- cls .engine = engine
222
+ engines = cls .engines if hasattr (cls , "engines" ) else {} # type:ignore
223
+
224
+ try :
225
+ engine = create_engine (connection_str )
226
+ except ImportError as import_error :
227
+ raise _get_missing_module_error (import_error ) from import_error
228
+ except NoSuchModuleError as exc :
229
+ raise _get_sql_alchemy_missing_error () from exc
230
+
231
+ engines [connection_str ] = engine
232
+ cls .engines = engines # type: ignore
223
233
224
234
def _describe (self ) -> Dict [str , Any ]:
225
- load_args = self ._load_args . copy ( )
226
- save_args = self ._save_args . copy ( )
235
+ load_args = copy . deepcopy ( self ._load_args )
236
+ save_args = copy . deepcopy ( self ._save_args )
227
237
del load_args ["table_name" ]
228
- del load_args ["con" ]
229
238
del save_args ["name" ]
230
- del save_args ["con" ]
231
239
return dict (
232
240
table_name = self ._load_args ["table_name" ],
233
241
load_args = load_args ,
234
242
save_args = save_args ,
235
243
)
236
244
237
245
def _load (self ) -> pd .DataFrame :
238
- load_args = copy .deepcopy (self ._load_args )
239
- load_args ["con" ] = self .engine # type: ignore
240
-
241
- try :
242
- return pd .read_sql_table (** load_args )
243
- except ImportError as import_error :
244
- raise _get_missing_module_error (import_error ) from import_error
245
- except NoSuchModuleError as exc :
246
- raise _get_sql_alchemy_missing_error () from exc
246
+ engine = self .engines .get (self ._connection_str ) # type:ignore
247
+ return pd .read_sql_table (con = engine , ** self ._load_args )
247
248
248
249
def _save (self , data : pd .DataFrame ) -> None :
249
- save_args = copy .deepcopy (self ._save_args )
250
- save_args ["con" ] = self .engine # type: ignore
251
-
252
- try :
253
- data .to_sql (** save_args )
254
- except ImportError as import_error :
255
- raise _get_missing_module_error (import_error ) from import_error
256
- except NoSuchModuleError as exc :
257
- raise _get_sql_alchemy_missing_error () from exc
250
+ engine = self .engines .get (self ._connection_str ) # type: ignore
251
+ data .to_sql (con = engine , ** self ._save_args )
258
252
259
253
def _exists (self ) -> bool :
260
- eng = self .engine # type: ignore
254
+ eng = self .engines [ self . _connection_str ] # type: ignore
261
255
schema = self ._load_args .get ("schema" , None )
262
256
exists = self ._load_args ["table_name" ] in eng .table_names (schema )
263
- # eng.dispose()
264
257
return exists
265
258
266
259
@@ -392,45 +385,48 @@ def __init__( # pylint: disable=too-many-arguments
392
385
self ._protocol = protocol
393
386
self ._fs = fsspec .filesystem (self ._protocol , ** _fs_credentials , ** _fs_args )
394
387
self ._filepath = path
395
- self ._load_args [ "con" ] = credentials ["con" ]
396
- self .create_connection (self ._load_args [ "con" ] )
388
+ self ._connection_str = credentials ["con" ]
389
+ self .create_connection (self ._connection_str )
397
390
398
391
@classmethod
399
- def create_connection (cls , con ):
400
- """Create singleton connection to be used
401
- across all instances of `SQLQueryDataSet`.
392
+ def create_connection (cls , connection_str : str ) -> None :
393
+ """Given a connection string, create singleton connection
394
+ to be used across all instances of `SQLQueryDataSet` that
395
+ need to connect to the same source.
402
396
"""
403
- if hasattr (cls , "engine" ):
397
+ if connection_str in getattr (cls , "engines" , {} ):
404
398
return
405
399
406
- engine = create_engine (con )
407
- cls .engine = engine
400
+ engines = cls .engines if hasattr (cls , "engines" ) else {} # type:ignore
401
+
402
+ try :
403
+ engine = create_engine (connection_str )
404
+ except ImportError as import_error :
405
+ raise _get_missing_module_error (import_error ) from import_error
406
+ except NoSuchModuleError as exc :
407
+ raise _get_sql_alchemy_missing_error () from exc
408
+
409
+ engines [connection_str ] = engine
410
+ cls .engines = engines # type: ignore
408
411
409
412
def _describe (self ) -> Dict [str , Any ]:
410
413
load_args = copy .deepcopy (self ._load_args )
411
- desc = {}
412
- desc ["sql" ] = str (load_args .pop ("sql" , None ))
413
- desc ["filepath" ] = str (self ._filepath )
414
- del load_args ["con" ]
415
- desc ["load_args" ] = str (load_args )
416
-
417
- return desc
414
+ return dict (
415
+ sql = str (load_args .pop ("sql" , None )),
416
+ filepath = str (self ._filepath ),
417
+ load_args = str (load_args ),
418
+ )
418
419
419
420
def _load (self ) -> pd .DataFrame :
420
421
load_args = copy .deepcopy (self ._load_args )
421
- load_args [ "con" ] = self .engine # type: ignore
422
+ engine = self .engines [ self . _connection_str ] # type: ignore
422
423
423
424
if self ._filepath :
424
425
load_path = get_filepath_str (PurePosixPath (self ._filepath ), self ._protocol )
425
426
with self ._fs .open (load_path , mode = "r" ) as fs_file :
426
427
load_args ["sql" ] = fs_file .read ()
427
428
428
- try :
429
- return pd .read_sql_query (** load_args )
430
- except ImportError as import_error :
431
- raise _get_missing_module_error (import_error ) from import_error
432
- except NoSuchModuleError as exc :
433
- raise _get_sql_alchemy_missing_error () from exc
429
+ return pd .read_sql_query (con = engine , ** load_args )
434
430
435
431
def _save (self , data : pd .DataFrame ) -> None :
436
432
raise DataSetError ("`save` is not supported on SQLQueryDataSet" )
0 commit comments