9
9
10
10
import numbers
11
11
from platform import system
12
+
12
13
if system () != "Windows" :
13
14
import jax
14
15
15
16
from jax .config import config
17
+
16
18
config .update ("jax_enable_x64" , True )
17
19
18
20
@@ -95,30 +97,35 @@ def find_symbols(symbol, constant_symbols, variable_symbols, to_dense=False):
95
97
dummy_eval_left = symbol .children [0 ].evaluate_for_shape ()
96
98
dummy_eval_right = symbol .children [1 ].evaluate_for_shape ()
97
99
if not to_dense and scipy .sparse .issparse (dummy_eval_left ):
98
- symbol_str = "{0}.multiply({1})" \
99
- .format (children_vars [0 ], children_vars [1 ])
100
+ symbol_str = "{0}.multiply({1})" .format (
101
+ children_vars [0 ], children_vars [1 ]
102
+ )
100
103
elif not to_dense and scipy .sparse .issparse (dummy_eval_right ):
101
- symbol_str = "{1}.multiply({0})" \
102
- .format (children_vars [0 ], children_vars [1 ])
104
+ symbol_str = "{1}.multiply({0})" .format (
105
+ children_vars [0 ], children_vars [1 ]
106
+ )
103
107
else :
104
108
symbol_str = "{0} * {1}" .format (children_vars [0 ], children_vars [1 ])
105
109
elif isinstance (symbol , pybamm .Division ):
106
110
dummy_eval_left = symbol .children [0 ].evaluate_for_shape ()
107
111
if not to_dense and scipy .sparse .issparse (dummy_eval_left ):
108
- symbol_str = "{0}.multiply(1/{1})" \
109
- .format (children_vars [0 ], children_vars [1 ])
112
+ symbol_str = "{0}.multiply(1/{1})" .format (
113
+ children_vars [0 ], children_vars [1 ]
114
+ )
110
115
else :
111
116
symbol_str = "{0} / {1}" .format (children_vars [0 ], children_vars [1 ])
112
117
113
118
elif isinstance (symbol , pybamm .Inner ):
114
119
dummy_eval_left = symbol .children [0 ].evaluate_for_shape ()
115
120
dummy_eval_right = symbol .children [1 ].evaluate_for_shape ()
116
121
if not to_dense and scipy .sparse .issparse (dummy_eval_left ):
117
- symbol_str = "{0}.multiply({1})" \
118
- .format (children_vars [0 ], children_vars [1 ])
122
+ symbol_str = "{0}.multiply({1})" .format (
123
+ children_vars [0 ], children_vars [1 ]
124
+ )
119
125
elif not to_dense and scipy .sparse .issparse (dummy_eval_right ):
120
- symbol_str = "{1}.multiply({0})" \
121
- .format (children_vars [0 ], children_vars [1 ])
126
+ symbol_str = "{1}.multiply({0})" .format (
127
+ children_vars [0 ], children_vars [1 ]
128
+ )
122
129
else :
123
130
symbol_str = "{0} * {1}" .format (children_vars [0 ], children_vars [1 ])
124
131
@@ -294,18 +301,20 @@ def __init__(self, symbol):
294
301
# extract constants in generated function
295
302
for i , symbol_id in enumerate (constants .keys ()):
296
303
const_name = id_to_python_variable (symbol_id , True )
297
- python_str = ' {} = constants[{}]\n ' .format (const_name , i ) + python_str
304
+ python_str = " {} = constants[{}]\n " .format (const_name , i ) + python_str
298
305
299
306
# constants passed in as an ordered dict, convert to list
300
307
self ._constants = list (constants .values ())
301
308
302
309
# indent code
303
- python_str = ' ' + python_str
304
- python_str = python_str .replace (' \n ' , ' \n ' )
310
+ python_str = " " + python_str
311
+ python_str = python_str .replace (" \n " , " \n " )
305
312
306
313
# add function def to first line
307
- python_str = 'def evaluate(constants, t=None, y=None, ' \
308
- 'y_dot=None, inputs=None, known_evals=None):\n ' + python_str
314
+ python_str = (
315
+ "def evaluate(constants, t=None, y=None, "
316
+ "y_dot=None, inputs=None, known_evals=None):\n " + python_str
317
+ )
309
318
310
319
# calculate the final variable that will output the result of calling `evaluate`
311
320
# on `symbol`
@@ -315,21 +324,18 @@ def __init__(self, symbol):
315
324
316
325
# add return line
317
326
if symbol .is_constant () and isinstance (result_value , numbers .Number ):
318
- python_str = python_str + ' \n return ' + str (result_value )
327
+ python_str = python_str + " \n return " + str (result_value )
319
328
else :
320
- python_str = python_str + ' \n return ' + result_var
329
+ python_str = python_str + " \n return " + result_var
321
330
322
331
# store a copy of examine_jaxpr
323
- python_str = python_str + \
324
- '\n self._evaluate = evaluate'
332
+ python_str = python_str + "\n self._evaluate = evaluate"
325
333
326
334
self ._python_str = python_str
327
335
self ._symbol = symbol
328
336
329
337
# compile and run the generated python code,
330
- compiled_function = compile (
331
- python_str , result_var , "exec"
332
- )
338
+ compiled_function = compile (python_str , result_var , "exec" )
333
339
exec (compiled_function )
334
340
335
341
def evaluate (self , t = None , y = None , y_dot = None , inputs = None , known_evals = None ):
@@ -377,7 +383,7 @@ def __init__(self, symbol):
377
383
constants , python_str = pybamm .to_python (symbol , debug = False , to_dense = True )
378
384
379
385
# replace numpy function calls to jax numpy calls
380
- python_str = python_str .replace (' np.' , ' jax.numpy.' )
386
+ python_str = python_str .replace (" np." , " jax.numpy." )
381
387
382
388
# convert all numpy constants to device vectors
383
389
for symbol_id in constants :
@@ -387,18 +393,20 @@ def __init__(self, symbol):
387
393
# extract constants in generated function
388
394
for i , symbol_id in enumerate (constants .keys ()):
389
395
const_name = id_to_python_variable (symbol_id , True )
390
- python_str = ' {} = constants[{}]\n ' .format (const_name , i ) + python_str
396
+ python_str = " {} = constants[{}]\n " .format (const_name , i ) + python_str
391
397
392
398
# constants passed in as an ordered dict, convert to list
393
399
self ._constants = list (constants .values ())
394
400
395
401
# indent code
396
- python_str = ' ' + python_str
397
- python_str = python_str .replace (' \n ' , ' \n ' )
402
+ python_str = " " + python_str
403
+ python_str = python_str .replace (" \n " , " \n " )
398
404
399
405
# add function def to first line
400
- python_str = 'def evaluate_jax(constants, t=None, y=None, ' \
401
- 'y_dot=None, inputs=None, known_evals=None):\n ' + python_str
406
+ python_str = (
407
+ "def evaluate_jax(constants, t=None, y=None, "
408
+ "y_dot=None, inputs=None, known_evals=None):\n " + python_str
409
+ )
402
410
403
411
# calculate the final variable that will output the result of calling `evaluate`
404
412
# on `symbol`
@@ -408,18 +416,15 @@ def __init__(self, symbol):
408
416
409
417
# add return line
410
418
if symbol .is_constant () and isinstance (result_value , numbers .Number ):
411
- python_str = python_str + ' \n return ' + str (result_value )
419
+ python_str = python_str + " \n return " + str (result_value )
412
420
else :
413
- python_str = python_str + ' \n return ' + result_var
421
+ python_str = python_str + " \n return " + result_var
414
422
415
423
# store a copy of examine_jaxpr
416
- python_str = python_str + \
417
- '\n self._evaluate_jax = evaluate_jax'
424
+ python_str = python_str + "\n self._evaluate_jax = evaluate_jax"
418
425
419
426
# compile and run the generated python code,
420
- compiled_function = compile (
421
- python_str , result_var , "exec"
422
- )
427
+ compiled_function = compile (python_str , result_var , "exec" )
423
428
exec (compiled_function )
424
429
425
430
self ._jit_evaluate = jax .jit (self ._evaluate_jax , static_argnums = (0 , 4 , 5 ))
0 commit comments