@@ -46,8 +46,6 @@ def __init__(self, img, bc=(-0.5, 0.5), D_0=1, device=torch.device('cuda')):
46
46
raise ValueError (
47
47
f'Input image must only contain 0s and 1s. Your image must be segmented to use this tool. If your image has been segmented, ensure your labels are 0 for non-conductive and 1 for conductive phase. Your image has the following labels: { torch .unique (img ).numpy ()} . If you have more than one conductive phase, use the multi-phase solver.' )
48
48
49
- # calculate
50
-
51
49
# init conc
52
50
self .conc = self .init_conc (img )
53
51
# create nn map
@@ -172,7 +170,6 @@ def check_convergence(self, verbose, conv_crit):
172
170
abs (self .top_bc - self .bot_bc )).cpu ()
173
171
self .tau = self .VF / \
174
172
self .D_rel if self .D_rel != 0 else torch .tensor (torch .inf )
175
-
176
173
177
174
if verbose == 'per_iter' :
178
175
print (
@@ -192,15 +189,16 @@ def check_convergence(self, verbose, conv_crit):
192
189
193
190
def calc_vertical_flux (self ):
194
191
'''Calculates the vertical flux through the volume'''
195
- vert_flux = self .conc [:, 1 :- 1 , 1 :- 1 , 1 :- 1 ] - \
196
- self .conc [:, :- 2 , 1 :- 1 , 1 :- 1 ]
197
- vert_flux [self .conc [:, :- 2 , 1 :- 1 , 1 :- 1 ] == 0 ] = 0
198
- vert_flux [self .conc [:, 1 :- 1 , 1 :- 1 , 1 :- 1 ] == 0 ] = 0
192
+ # Indexing removes boundary layers (1 layer at every boundary)
193
+ vert_flux = self .conc [:, 2 :- 1 , 1 :- 1 , 1 :- 1 ] - \
194
+ self .conc [:, 1 :- 2 , 1 :- 1 , 1 :- 1 ]
195
+ vert_flux [self .conc [:, 1 :- 2 , 1 :- 1 , 1 :- 1 ] == 0 ] = 0
196
+ vert_flux [self .conc [:, 2 :- 1 , 1 :- 1 , 1 :- 1 ] == 0 ] = 0
199
197
return vert_flux
200
-
198
+
201
199
def check_vertical_flux (self , conv_crit ):
202
200
vert_flux = self .calc_vertical_flux ()
203
- fl = torch .sum (vert_flux , (0 , 2 , 3 ))[ 1 : - 1 ]
201
+ fl = torch .sum (vert_flux , (0 , 2 , 3 ))
204
202
err = (fl .max () - fl .min ())/ (fl .max ())
205
203
if fl .min () == 0 :
206
204
return 'zero_flux' , torch .mean (fl ), err
@@ -292,21 +290,12 @@ def solve(self, iter_limit=5000, verbose=True, conv_crit=2*10**-2, D_0=1):
292
290
293
291
def calc_vertical_flux (self ):
294
292
'''Calculates the vertical flux through the volume'''
295
- vert_flux = abs (self .conc - torch .roll (self .conc , 1 , 1 ))
296
- vert_flux [self .conc == 0 ] = 0
297
- vert_flux [torch .roll (self .conc , 1 , 1 ) == 0 ] = 0
293
+ # Indexing removes 2 boundary layers at top and bottom
294
+ vert_flux = self .conc [:, 3 :- 2 ] - self .conc [:, 2 :- 3 ]
295
+ vert_flux [self .conc [:, 3 :- 2 ] == 0 ] = 0
296
+ vert_flux [self .conc [:, 2 :- 3 ] == 0 ] = 0
298
297
return vert_flux
299
298
300
- def check_vertical_flux (self , conv_crit ):
301
- vert_flux = self .calc_vertical_flux ()
302
- fl = torch .sum (vert_flux , (0 , 2 , 3 ))[3 :- 2 ]
303
- err = (fl .max () - fl .min ())* 2 / (fl .max () + fl .min ())
304
- if err < conv_crit or torch .isnan (err ).item ():
305
- return True , torch .mean (fl ), err
306
- if fl .min () == 0 :
307
- return 'zero_flux' , torch .mean (fl ), err
308
- return False , torch .mean (fl ), err
309
-
310
299
311
300
class MultiPhaseSolver (Solver ):
312
301
"""
@@ -348,8 +337,6 @@ def __init__(self, img, cond={1: 1}, bc=(-0.5, 0.5), device=torch.device('cuda:0
348
337
# save original image in cuda
349
338
img = torch .tensor (img , dtype = self .precision , device = self .device )
350
339
351
- # calculate
352
-
353
340
# init conc
354
341
self .conc = self .init_conc (img )
355
342
# create nn map
@@ -498,21 +485,11 @@ def check_convergence(self, verbose, conv_crit):
498
485
499
486
def calc_vertical_flux (self ):
500
487
'''Calculates the vertical flux through the volume'''
501
- vert_flux = (self .conc [:, 1 :- 1 , 1 :- 1 , 1 :- 1 ] - self .conc [:,
502
- :- 2 , 1 :- 1 , 1 :- 1 ]) * self .pre_factors [1 ][:, :- 2 , 1 :- 1 , 1 :- 1 ]
503
- vert_flux [self .nn == torch .inf ] = 0
488
+ vert_flux = (self .conc [:, 2 :- 1 , 1 :- 1 , 1 :- 1 ] - self .conc [:,
489
+ 1 :- 2 , 1 :- 1 , 1 :- 1 ]) * self .pre_factors [1 ][:, 1 :- 2 , 1 :- 1 , 1 :- 1 ]
490
+ vert_flux [self .nn [:, 1 :] == torch .inf ] = 0
504
491
return vert_flux
505
492
506
- def check_vertical_flux (self , conv_crit ):
507
- vert_flux = self .calc_vertical_flux ()
508
- fl = torch .sum (vert_flux , (0 , 2 , 3 ))[2 :- 2 ]
509
- err = (fl .max () - fl .min ())* 2 / (fl .max () + fl .min ())
510
- if err < conv_crit or torch .isnan (err ).item ():
511
- return True , torch .mean (fl ), err
512
- if fl .min () == 0 :
513
- return 'zero_flux' , torch .mean (fl ), err
514
- return False , torch .mean (fl ), err
515
-
516
493
517
494
class ElectrodeSolver ():
518
495
"""
0 commit comments