@@ -324,37 +324,43 @@ def test_extrapolation_warnings(self):
324
324
325
325
def test_sensitivities (self ):
326
326
327
- def exact_diff_a (v , a , b ):
328
- return np .array ([v ** 2 + 2 * a ])
327
+ def exact_diff_a (y , a , b ):
328
+ return np .array ([
329
+ [y [0 ]** 2 + 2 * a ],
330
+ [y [0 ]]
331
+ ])
329
332
330
- def exact_diff_b (v , a , b ):
331
- return np .array ([v ])
333
+ def exact_diff_b (y , a , b ):
334
+ return np .array ([[ y [ 0 ]], [ 0 ] ])
332
335
333
- for f in ['' , 'python' , 'casadi' , 'jax' ]:
336
+ for convert_to_format in ['' , 'python' , 'casadi' , 'jax' ]:
334
337
model = pybamm .BaseModel ()
335
338
v = pybamm .Variable ("v" )
339
+ u = pybamm .Variable ("u" )
336
340
a = pybamm .InputParameter ("a" )
337
341
b = pybamm .InputParameter ("b" )
338
342
model .rhs = {v : a * v ** 2 + b * v + a ** 2 }
339
- model .initial_conditions = {v : 1 }
340
- model .convert_to_format = f
341
- solver = pybamm .ScipySolver ()
343
+ model .algebraic = {u : a * v - u }
344
+ model .initial_conditions = {v : 1 , u : a * 1 }
345
+ model .convert_to_format = convert_to_format
346
+ solver = pybamm .CasadiSolver ()
342
347
solver .set_up (model , calculate_sensitivites = True ,
343
348
inputs = {'a' : 0 , 'b' : 0 })
344
349
all_inputs = []
345
350
for v_value in [0.1 , - 0.2 , 1.5 , 8.4 ]:
346
- for a_value in [0.12 , 1.5 ]:
347
- for b_value in [0.82 , 1.9 ]:
348
- y = np .array ([v_value ])
349
- t = 0
350
- inputs = {'a' : a_value , 'b' : b_value }
351
- all_inputs .append ((t , y , inputs ))
351
+ for u_value in [0.13 , - 0.23 , 1.3 , 13.4 ]:
352
+ for a_value in [0.12 , 1.5 ]:
353
+ for b_value in [0.82 , 1.9 ]:
354
+ y = np .array ([v_value , u_value ])
355
+ t = 0
356
+ inputs = {'a' : a_value , 'b' : b_value }
357
+ all_inputs .append ((t , y , inputs ))
352
358
for t , y , inputs in all_inputs :
353
- if f == 'casadi' :
359
+ if model . convert_to_format == 'casadi' :
354
360
use_inputs = casadi .vertcat (* [x for x in inputs .values ()])
355
361
else :
356
362
use_inputs = inputs
357
- if f == 'jax' :
363
+ if model . convert_to_format == 'jax' :
358
364
sens = model .sensitivities_eval (
359
365
t , y , use_inputs
360
366
)
0 commit comments