Skip to content

Commit 452ba2f

Browse files
Add allen cahn sota (PaddlePaddle#879)
* add allen cahn ntk * update code * add allen cahn ntk * update code * update code * update code * 修改配置 --------- Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent e5cdf29 commit 452ba2f

File tree

8 files changed

+924
-6
lines changed

8 files changed

+924
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
"""
2+
Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
3+
"""
4+
5+
from os import path as osp
6+
7+
import hydra
8+
import numpy as np
9+
import paddle
10+
import scipy.io as sio
11+
from matplotlib import pyplot as plt
12+
from omegaconf import DictConfig
13+
14+
import ppsci
15+
from ppsci.loss import mtl
16+
from ppsci.utils import misc
17+
18+
dtype = paddle.get_default_dtype()
19+
20+
21+
def plot(
22+
t_star: np.ndarray,
23+
x_star: np.ndarray,
24+
u_ref: np.ndarray,
25+
u_pred: np.ndarray,
26+
output_dir: str,
27+
):
28+
fig = plt.figure(figsize=(18, 5))
29+
TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
30+
u_ref = u_ref.reshape([len(t_star), len(x_star)])
31+
32+
plt.subplot(1, 3, 1)
33+
plt.pcolor(TT, XX, u_ref, cmap="jet")
34+
plt.colorbar()
35+
plt.xlabel("t")
36+
plt.ylabel("x")
37+
plt.title("Exact")
38+
plt.tight_layout()
39+
40+
plt.subplot(1, 3, 2)
41+
plt.pcolor(TT, XX, u_pred, cmap="jet")
42+
plt.colorbar()
43+
plt.xlabel("t")
44+
plt.ylabel("x")
45+
plt.title("Predicted")
46+
plt.tight_layout()
47+
48+
plt.subplot(1, 3, 3)
49+
plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
50+
plt.colorbar()
51+
plt.xlabel("t")
52+
plt.ylabel("x")
53+
plt.title("Absolute error")
54+
plt.tight_layout()
55+
56+
fig_path = osp.join(output_dir, "ac.png")
57+
print(f"Saving figure to {fig_path}")
58+
fig.savefig(fig_path, bbox_inches="tight", dpi=400)
59+
plt.close()
60+
61+
62+
def train(cfg: DictConfig):
63+
# set model
64+
model = ppsci.arch.MLP(**cfg.MODEL)
65+
66+
# set equation
67+
equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}
68+
69+
# set constraint
70+
data = sio.loadmat(cfg.DATA_PATH)
71+
u_ref = data["usol"].astype(dtype) # (nt, nx)
72+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
73+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
74+
75+
u0 = u_ref[0, :] # [nx, ]
76+
77+
t0 = t_star[0] # float
78+
t1 = t_star[-1] # float
79+
80+
x0 = x_star[0] # float
81+
x1 = x_star[-1] # float
82+
83+
def gen_input_batch():
84+
tx = np.random.uniform(
85+
[t0, x0],
86+
[t1, x1],
87+
(cfg.TRAIN.batch_size, 2),
88+
).astype(dtype)
89+
return {
90+
"t": np.sort(tx[:, 0:1], axis=0),
91+
"x": tx[:, 1:2],
92+
}
93+
94+
def gen_label_batch(input_batch):
95+
return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}
96+
97+
pde_constraint = ppsci.constraint.SupervisedConstraint(
98+
{
99+
"dataset": {
100+
"name": "ContinuousNamedArrayDataset",
101+
"input": gen_input_batch,
102+
"label": gen_label_batch,
103+
},
104+
},
105+
output_expr=equation["AllenCahn"].equations,
106+
loss=ppsci.loss.CausalMSELoss(
107+
cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol
108+
),
109+
name="PDE",
110+
)
111+
112+
ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
113+
ic_label = {"u": u0.reshape([-1, 1])}
114+
ic = ppsci.constraint.SupervisedConstraint(
115+
{
116+
"dataset": {
117+
"name": "IterableNamedArrayDataset",
118+
"input": ic_input,
119+
"label": ic_label,
120+
},
121+
},
122+
output_expr={"u": lambda out: out["u"]},
123+
loss=ppsci.loss.MSELoss("mean"),
124+
name="IC",
125+
)
126+
# wrap constraints together
127+
constraint = {
128+
pde_constraint.name: pde_constraint,
129+
ic.name: ic,
130+
}
131+
132+
# set optimizer
133+
lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
134+
**cfg.TRAIN.lr_scheduler
135+
)()
136+
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
137+
138+
# set validator
139+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
140+
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
141+
eval_label = {"u": u_ref.reshape([-1, 1])}
142+
u_validator = ppsci.validate.SupervisedValidator(
143+
{
144+
"dataset": {
145+
"name": "NamedArrayDataset",
146+
"input": eval_data,
147+
"label": eval_label,
148+
},
149+
"batch_size": cfg.EVAL.batch_size,
150+
},
151+
ppsci.loss.MSELoss("mean"),
152+
{"u": lambda out: out["u"]},
153+
metric={"L2Rel": ppsci.metric.L2Rel()},
154+
name="u_validator",
155+
)
156+
validator = {u_validator.name: u_validator}
157+
158+
# initialize solver
159+
solver = ppsci.solver.Solver(
160+
model,
161+
constraint,
162+
cfg.output_dir,
163+
optimizer,
164+
epochs=cfg.TRAIN.epochs,
165+
iters_per_epoch=cfg.TRAIN.iters_per_epoch,
166+
save_freq=cfg.TRAIN.save_freq,
167+
log_freq=cfg.log_freq,
168+
eval_during_train=True,
169+
eval_freq=cfg.TRAIN.eval_freq,
170+
equation=equation,
171+
validator=validator,
172+
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
173+
checkpoint_path=cfg.TRAIN.checkpoint_path,
174+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
175+
loss_aggregator=mtl.NTK(
176+
model,
177+
len(constraint),
178+
cfg.TRAIN.ntk.update_freq,
179+
),
180+
cfg=cfg,
181+
)
182+
# train model
183+
solver.train()
184+
# evaluate after finished training
185+
solver.eval()
186+
# visualize prediction after finished training
187+
u_pred = solver.predict(
188+
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
189+
)["u"]
190+
u_pred = u_pred.reshape([len(t_star), len(x_star)])
191+
192+
# plot
193+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
194+
195+
196+
def evaluate(cfg: DictConfig):
197+
# set model
198+
model = ppsci.arch.MLP(**cfg.MODEL)
199+
200+
data = sio.loadmat(cfg.DATA_PATH)
201+
u_ref = data["usol"].astype(dtype) # (nt, nx)
202+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
203+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
204+
205+
# set validator
206+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
207+
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
208+
eval_label = {"u": u_ref.reshape([-1, 1])}
209+
u_validator = ppsci.validate.SupervisedValidator(
210+
{
211+
"dataset": {
212+
"name": "NamedArrayDataset",
213+
"input": eval_data,
214+
"label": eval_label,
215+
},
216+
"batch_size": cfg.EVAL.batch_size,
217+
},
218+
ppsci.loss.MSELoss("mean"),
219+
{"u": lambda out: out["u"]},
220+
metric={"L2Rel": ppsci.metric.L2Rel()},
221+
name="u_validator",
222+
)
223+
validator = {u_validator.name: u_validator}
224+
225+
# initialize solver
226+
solver = ppsci.solver.Solver(
227+
model,
228+
output_dir=cfg.output_dir,
229+
log_freq=cfg.log_freq,
230+
validator=validator,
231+
pretrained_model_path=cfg.EVAL.pretrained_model_path,
232+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
233+
)
234+
235+
# evaluate after finished training
236+
solver.eval()
237+
# visualize prediction after finished training
238+
u_pred = solver.predict(
239+
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
240+
)["u"]
241+
u_pred = u_pred.reshape([len(t_star), len(x_star)])
242+
243+
# plot
244+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
245+
246+
247+
def export(cfg: DictConfig):
248+
# set model
249+
model = ppsci.arch.MLP(**cfg.MODEL)
250+
251+
# initialize solver
252+
solver = ppsci.solver.Solver(
253+
model,
254+
pretrained_model_path=cfg.INFER.pretrained_model_path,
255+
)
256+
# export model
257+
from paddle.static import InputSpec
258+
259+
input_spec = [
260+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
261+
]
262+
solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)
263+
264+
265+
def inference(cfg: DictConfig):
266+
from deploy.python_infer import pinn_predictor
267+
268+
predictor = pinn_predictor.PINNPredictor(cfg)
269+
data = sio.loadmat(cfg.DATA_PATH)
270+
u_ref = data["usol"].astype(dtype) # (nt, nx)
271+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
272+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
273+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
274+
275+
input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
276+
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
277+
output_dict = {
278+
store_key: output_dict[infer_key]
279+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
280+
}
281+
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
282+
# mapping data to cfg.INFER.output_keys
283+
284+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
285+
286+
287+
@hydra.main(
288+
version_base=None, config_path="./conf", config_name="allen_cahn_defalut_ntk.yaml"
289+
)
290+
def main(cfg: DictConfig):
291+
if cfg.mode == "train":
292+
train(cfg)
293+
elif cfg.mode == "eval":
294+
evaluate(cfg)
295+
elif cfg.mode == "export":
296+
export(cfg)
297+
elif cfg.mode == "infer":
298+
inference(cfg)
299+
else:
300+
raise ValueError(
301+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
302+
)
303+
304+
305+
if __name__ == "__main__":
306+
main()

0 commit comments

Comments
 (0)