@@ -207,60 +207,57 @@ def __init__(
207
207
self ._load_args ["table_name" ] = table_name
208
208
self ._save_args ["name" ] = table_name
209
209
210
- self ._load_args [ "con" ] = self . _save_args [ "con" ] = credentials ["con" ]
211
- self .create_connection (self ._load_args [ "con" ] )
210
+ self ._connection_str = credentials ["con" ]
211
+ self .create_connection (self ._connection_str )
212
212
213
213
@classmethod
214
- def create_connection (cls , con ):
215
- """Create singleton connection to be used
216
- across all instances of `SQLTableDataSet`.
214
+ def create_connection (cls , connection_str : str ) -> None :
215
+ """Given a connection string, create singleton connection
216
+ to be used across all instances of `SQLTableDataSet` that
217
+ need to connect to the same source.
217
218
"""
218
- if hasattr (cls , "engine" ):
219
+ if connection_str in getattr (cls , "engines" , {} ):
219
220
return
220
221
221
- engine = create_engine (con )
222
- cls .engine = engine
222
+ engines = cls .engines if hasattr (cls , "engines" ) else {} # type:ignore
223
+
224
+ try :
225
+ engine = create_engine (connection_str )
226
+ except ImportError as import_error :
227
+ raise _get_missing_module_error (import_error ) from import_error
228
+ except NoSuchModuleError as exc :
229
+ raise _get_sql_alchemy_missing_error () from exc
230
+
231
+ engines [connection_str ] = engine
232
+ cls .engines = engines # type: ignore
223
233
224
234
def _describe (self ) -> Dict [str , Any ]:
225
- load_args = self ._load_args . copy ( )
226
- save_args = self ._save_args . copy ( )
235
+ load_args = copy . deepcopy ( self ._load_args )
236
+ save_args = copy . deepcopy ( self ._save_args )
227
237
del load_args ["table_name" ]
228
- del load_args ["con" ]
229
238
del save_args ["name" ]
230
- del save_args ["con" ]
231
239
return dict (
232
240
table_name = self ._load_args ["table_name" ],
233
241
load_args = load_args ,
234
242
save_args = save_args ,
235
243
)
236
244
237
245
def _load (self ) -> pd .DataFrame :
238
- load_args = copy .deepcopy (self ._load_args )
239
- load_args ["con" ] = self .engine # type: ignore
246
+ engine = self .engines .get (self ._connection_str ) # type:ignore
240
247
241
- try :
242
- return pd .read_sql_table (** load_args )
243
- except ImportError as import_error :
244
- raise _get_missing_module_error (import_error ) from import_error
245
- except NoSuchModuleError as exc :
246
- raise _get_sql_alchemy_missing_error () from exc
248
+ # TODO: handle engine = None
249
+ return pd .read_sql_table (con = engine , ** self ._load_args )
247
250
248
251
def _save (self , data : pd .DataFrame ) -> None :
249
- save_args = copy .deepcopy (self ._save_args )
250
- save_args ["con" ] = self .engine # type: ignore
252
+ engine = self .engines .get (self ._connection_str ) # type: ignore
251
253
252
- try :
253
- data .to_sql (** save_args )
254
- except ImportError as import_error :
255
- raise _get_missing_module_error (import_error ) from import_error
256
- except NoSuchModuleError as exc :
257
- raise _get_sql_alchemy_missing_error () from exc
254
+ # TODO: handle engine = None
255
+ data .to_sql (con = engine , ** self ._save_args )
258
256
259
257
def _exists (self ) -> bool :
260
- eng = self .engine # type: ignore
258
+ eng = self .engines [ self . _connection_str ] # type: ignore
261
259
schema = self ._load_args .get ("schema" , None )
262
260
exists = self ._load_args ["table_name" ] in eng .table_names (schema )
263
- # eng.dispose()
264
261
return exists
265
262
266
263
@@ -392,45 +389,48 @@ def __init__( # pylint: disable=too-many-arguments
392
389
self ._protocol = protocol
393
390
self ._fs = fsspec .filesystem (self ._protocol , ** _fs_credentials , ** _fs_args )
394
391
self ._filepath = path
395
- self ._load_args [ "con" ] = credentials ["con" ]
396
- self .create_connection (self ._load_args [ "con" ] )
392
+ self ._connection_str = credentials ["con" ]
393
+ self .create_connection (self ._connection_str )
397
394
398
395
@classmethod
399
- def create_connection (cls , con ):
400
- """Create singleton connection to be used
401
- across all instances of `SQLQueryDataSet`.
396
+ def create_connection (cls , connection_str : str ) -> None :
397
+ """Given a connection string, create singleton connection
398
+ to be used across all instances of `SQLQueryDataSet` that
399
+ need to connect to the same source.
402
400
"""
403
- if hasattr (cls , "engine" ):
401
+ if connection_str in getattr (cls , "engines" , {} ):
404
402
return
405
403
406
- engine = create_engine (con )
407
- cls .engine = engine
404
+ engines = cls .engines if hasattr (cls , "engines" ) else {} # type:ignore
405
+
406
+ try :
407
+ engine = create_engine (connection_str )
408
+ except ImportError as import_error :
409
+ raise _get_missing_module_error (import_error ) from import_error
410
+ except NoSuchModuleError as exc :
411
+ raise _get_sql_alchemy_missing_error () from exc
412
+
413
+ engines [connection_str ] = engine
414
+ cls .engines = engines # type: ignore
408
415
409
416
def _describe (self ) -> Dict [str , Any ]:
410
417
load_args = copy .deepcopy (self ._load_args )
411
- desc = {}
412
- desc ["sql" ] = str (load_args .pop ("sql" , None ))
413
- desc ["filepath" ] = str (self ._filepath )
414
- del load_args ["con" ]
415
- desc ["load_args" ] = str (load_args )
416
-
417
- return desc
418
+ return dict (
419
+ sql = str (load_args .pop ("sql" , None )),
420
+ filepath = str (self ._filepath ),
421
+ load_args = str (load_args ),
422
+ )
418
423
419
424
def _load (self ) -> pd .DataFrame :
420
425
load_args = copy .deepcopy (self ._load_args )
421
- load_args [ "con" ] = self .engine # type: ignore
426
+ engine = self .engines [ self . _connection_str ] # type: ignore
422
427
423
428
if self ._filepath :
424
429
load_path = get_filepath_str (PurePosixPath (self ._filepath ), self ._protocol )
425
430
with self ._fs .open (load_path , mode = "r" ) as fs_file :
426
431
load_args ["sql" ] = fs_file .read ()
427
432
428
- try :
429
- return pd .read_sql_query (** load_args )
430
- except ImportError as import_error :
431
- raise _get_missing_module_error (import_error ) from import_error
432
- except NoSuchModuleError as exc :
433
- raise _get_sql_alchemy_missing_error () from exc
433
+ return pd .read_sql_query (con = engine , ** load_args )
434
434
435
435
def _save (self , data : pd .DataFrame ) -> None :
436
436
raise DataSetError ("`save` is not supported on SQLQueryDataSet" )
0 commit comments