Skip to content

Commit 241e554

Browse files
authored
Merge pull request #104 from tldr-group/refactor
Use inheritance for check_vertical_flux
2 parents 982efc5 + 7e04615 commit 241e554

File tree

1 file changed

+14
-37
lines changed

1 file changed

+14
-37
lines changed

taufactor/taufactor.py

+14-37
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ def __init__(self, img, bc=(-0.5, 0.5), D_0=1, device=torch.device('cuda')):
4646
raise ValueError(
4747
f'Input image must only contain 0s and 1s. Your image must be segmented to use this tool. If your image has been segmented, ensure your labels are 0 for non-conductive and 1 for conductive phase. Your image has the following labels: {torch.unique(img).numpy()}. If you have more than one conductive phase, use the multi-phase solver.')
4848

49-
# calculate
50-
5149
# init conc
5250
self.conc = self.init_conc(img)
5351
# create nn map
@@ -172,7 +170,6 @@ def check_convergence(self, verbose, conv_crit):
172170
abs(self.top_bc - self.bot_bc)).cpu()
173171
self.tau = self.VF / \
174172
self.D_rel if self.D_rel != 0 else torch.tensor(torch.inf)
175-
176173

177174
if verbose == 'per_iter':
178175
print(
@@ -192,15 +189,16 @@ def check_convergence(self, verbose, conv_crit):
192189

193190
def calc_vertical_flux(self):
194191
'''Calculates the vertical flux through the volume'''
195-
vert_flux = self.conc[:, 1:-1, 1:-1, 1:-1] - \
196-
self.conc[:, :-2, 1:-1, 1:-1]
197-
vert_flux[self.conc[:, :-2, 1:-1, 1:-1] == 0] = 0
198-
vert_flux[self.conc[:, 1:-1, 1:-1, 1:-1] == 0] = 0
192+
# Indexing removes boundary layers (1 layer at every boundary)
193+
vert_flux = self.conc[:, 2:-1, 1:-1, 1:-1] - \
194+
self.conc[:, 1:-2, 1:-1, 1:-1]
195+
vert_flux[self.conc[:, 1:-2, 1:-1, 1:-1] == 0] = 0
196+
vert_flux[self.conc[:, 2:-1, 1:-1, 1:-1] == 0] = 0
199197
return vert_flux
200-
198+
201199
def check_vertical_flux(self, conv_crit):
202200
vert_flux = self.calc_vertical_flux()
203-
fl = torch.sum(vert_flux, (0, 2, 3))[1:-1]
201+
fl = torch.sum(vert_flux, (0, 2, 3))
204202
err = (fl.max() - fl.min())/(fl.max())
205203
if fl.min() == 0:
206204
return 'zero_flux', torch.mean(fl), err
@@ -292,21 +290,12 @@ def solve(self, iter_limit=5000, verbose=True, conv_crit=2*10**-2, D_0=1):
292290

293291
def calc_vertical_flux(self):
294292
'''Calculates the vertical flux through the volume'''
295-
vert_flux = abs(self.conc - torch.roll(self.conc, 1, 1))
296-
vert_flux[self.conc == 0] = 0
297-
vert_flux[torch.roll(self.conc, 1, 1) == 0] = 0
293+
# Indexing removes 2 boundary layers at top and bottom
294+
vert_flux = self.conc[:, 3:-2] - self.conc[:, 2:-3]
295+
vert_flux[self.conc[:, 3:-2] == 0] = 0
296+
vert_flux[self.conc[:, 2:-3] == 0] = 0
298297
return vert_flux
299298

300-
def check_vertical_flux(self, conv_crit):
301-
vert_flux = self.calc_vertical_flux()
302-
fl = torch.sum(vert_flux, (0, 2, 3))[3:-2]
303-
err = (fl.max() - fl.min())*2/(fl.max() + fl.min())
304-
if err < conv_crit or torch.isnan(err).item():
305-
return True, torch.mean(fl), err
306-
if fl.min() == 0:
307-
return 'zero_flux', torch.mean(fl), err
308-
return False, torch.mean(fl), err
309-
310299

311300
class MultiPhaseSolver(Solver):
312301
"""
@@ -348,8 +337,6 @@ def __init__(self, img, cond={1: 1}, bc=(-0.5, 0.5), device=torch.device('cuda:0
348337
# save original image in cuda
349338
img = torch.tensor(img, dtype=self.precision, device=self.device)
350339

351-
# calculate
352-
353340
# init conc
354341
self.conc = self.init_conc(img)
355342
# create nn map
@@ -498,21 +485,11 @@ def check_convergence(self, verbose, conv_crit):
498485

499486
def calc_vertical_flux(self):
500487
'''Calculates the vertical flux through the volume'''
501-
vert_flux = (self.conc[:, 1:-1, 1:-1, 1:-1] - self.conc[:,
502-
:-2, 1:-1, 1:-1]) * self.pre_factors[1][:, :-2, 1:-1, 1:-1]
503-
vert_flux[self.nn == torch.inf] = 0
488+
vert_flux = (self.conc[:, 2:-1, 1:-1, 1:-1] - self.conc[:,
489+
1:-2, 1:-1, 1:-1]) * self.pre_factors[1][:, 1:-2, 1:-1, 1:-1]
490+
vert_flux[self.nn[:,1:] == torch.inf] = 0
504491
return vert_flux
505492

506-
def check_vertical_flux(self, conv_crit):
507-
vert_flux = self.calc_vertical_flux()
508-
fl = torch.sum(vert_flux, (0, 2, 3))[2:-2]
509-
err = (fl.max() - fl.min())*2/(fl.max() + fl.min())
510-
if err < conv_crit or torch.isnan(err).item():
511-
return True, torch.mean(fl), err
512-
if fl.min() == 0:
513-
return 'zero_flux', torch.mean(fl), err
514-
return False, torch.mean(fl), err
515-
516493

517494
class ElectrodeSolver():
518495
"""

0 commit comments

Comments
 (0)