17
17
import os
18
18
import re
19
19
import sys
20
+ import threading
20
21
import warnings
22
+ from os .path import expandvars
21
23
from urllib .parse import (
22
24
parse_qs ,
23
25
ParseResult ,
40
42
Openable = (str , os .PathLike )
41
43
logger = logging .getLogger (__name__ )
42
44
45
+ # Variables which values should not be expanded
46
+ NOT_EXPANDED = 'DJANGO_SECRET_KEY' , 'CACHE_URL'
47
+
43
48
44
49
def _cast (value ):
45
50
# Safely evaluate an expression node or a string containing a Python
@@ -189,7 +194,11 @@ class Env:
189
194
for s in ('' , 's' )]
190
195
CLOUDSQL = 'cloudsql'
191
196
197
+ VAR = re .compile (r'(?<!\\)\$\{?(?P<name>[A-Z_][0-9A-Z_]*)}?' ,
198
+ re .IGNORECASE )
199
+
192
200
def __init__ (self , ** scheme ):
201
+ self ._local = threading .local ()
193
202
self .smart_cast = True
194
203
self .escape_proxy = False
195
204
self .prefix = ""
@@ -343,9 +352,13 @@ def path(self, var, default=NOTSET, **kwargs):
343
352
"""
344
353
return Path (self .get_value (var , default = default ), ** kwargs )
345
354
346
- def get_value (self , var , cast = None , default = NOTSET , parse_default = False ):
355
+ def get_value (self , var , cast = None , # pylint: disable=R0913
356
+ default = NOTSET , parse_default = False , add_prefix = True ):
347
357
"""Return value for given environment variable.
348
358
359
+ - Expand variables referenced as ``$VAR`` or ``${VAR}``.
360
+ - Detect infinite recursion in expansion (self-reference).
361
+
349
362
:param str var:
350
363
Name of variable.
351
364
:param collections.abc.Callable or None cast:
@@ -354,15 +367,33 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False):
354
367
If var not present in environ, return this instead.
355
368
:param bool parse_default:
356
369
Force to parse default.
370
+ :param bool add_prefix:
371
+ Whether to add prefix to variable name.
357
372
:returns: Value from environment or default (if set).
358
373
:rtype: typing.IO[typing.Any]
359
374
"""
360
-
375
+ var_name = f'{ self .prefix } { var } ' if add_prefix else var
376
+ if not hasattr (self ._local , 'vars' ):
377
+ self ._local .vars = set ()
378
+ if var_name in self ._local .vars :
379
+ error_msg = f"Environment variable '{ var_name } ' recursively " \
380
+ "references itself (eventually)"
381
+ raise ImproperlyConfigured (error_msg )
382
+
383
+ self ._local .vars .add (var_name )
384
+ try :
385
+ return self ._get_value (
386
+ var_name , cast = cast , default = default ,
387
+ parse_default = parse_default )
388
+ finally :
389
+ self ._local .vars .remove (var_name )
390
+
391
+ def _get_value (self , var_name , cast = None , default = NOTSET ,
392
+ parse_default = False ):
361
393
logger .debug (
362
394
"get '%s' casted as '%s' with default '%s'" ,
363
- var , cast , default )
395
+ var_name , cast , default )
364
396
365
- var_name = f'{ self .prefix } { var } '
366
397
if var_name in self .scheme :
367
398
var_info = self .scheme [var_name ]
368
399
@@ -388,26 +419,37 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False):
388
419
value = self .ENVIRON [var_name ]
389
420
except KeyError as exc :
390
421
if default is self .NOTSET :
391
- error_msg = f'Set the { var } environment variable'
422
+ error_msg = f'Set the { var_name } environment variable'
392
423
raise ImproperlyConfigured (error_msg ) from exc
393
424
394
425
value = default
395
426
427
+ # Expand variables
428
+ if isinstance (value , (bytes , str )) and var_name not in NOT_EXPANDED :
429
+ def repl (match_ ):
430
+ return self .get_value (
431
+ match_ .group ('name' ), cast = cast , default = default ,
432
+ parse_default = parse_default , add_prefix = False )
433
+
434
+ is_bytes = isinstance (value , bytes )
435
+ if is_bytes :
436
+ value = value .decode ('utf-8' )
437
+ value = self .VAR .sub (repl , value )
438
+ value = expandvars (value )
439
+ if is_bytes :
440
+ value = value .encode ('utf-8' )
441
+
396
442
# Resolve any proxied values
397
443
prefix = b'$' if isinstance (value , bytes ) else '$'
398
444
escape = rb'\$' if isinstance (value , bytes ) else r'\$'
399
- if hasattr (value , 'startswith' ) and value .startswith (prefix ):
400
- value = value .lstrip (prefix )
401
- value = self .get_value (value , cast = cast , default = default )
402
445
403
446
if self .escape_proxy and hasattr (value , 'replace' ):
404
447
value = value .replace (escape , prefix )
405
448
406
449
# Smart casting
407
- if self .smart_cast :
408
- if cast is None and default is not None and \
409
- not isinstance (default , NoValue ):
410
- cast = type (default )
450
+ if self .smart_cast and cast is None and default is not None \
451
+ and not isinstance (default , NoValue ):
452
+ cast = type (default )
411
453
412
454
value = None if default is None and value == '' else value
413
455
0 commit comments