Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] use Assign with fp64 in XPINNs #608

Merged
merged 2 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions jointContribution/XPINNs/XPINN_2D_PoissonsEqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from matplotlib import gridspec
from matplotlib import patches
from matplotlib import tri
from paddle import nn

import ppsci

Expand All @@ -19,7 +20,7 @@
paddle.seed(1234)


class XPINN(paddle.nn.Layer):
class XPINN(nn.Layer):
# Initialize the class
def __init__(self, layer_list):
super().__init__()
Expand Down Expand Up @@ -130,13 +131,13 @@ def initialize_nn(self, layers, name_prefix):
shape=[1, layers[l + 1]],
dtype="float64",
is_bias=True,
default_initializer=paddle.nn.initializer.Constant(0.0),
default_initializer=nn.initializer.Constant(0.0),
)
amplitude = self.create_parameter(
shape=[1],
dtype="float64",
is_bias=True,
default_initializer=paddle.nn.initializer.Constant(0.05),
default_initializer=nn.initializer.Constant(0.05),
)

self.add_parameter(name_prefix + "_w_" + str(l), weight)
Expand All @@ -153,8 +154,7 @@ def w_init(self, size):
xavier_stddev = np.sqrt(2 / (in_dim + out_dim))
param = paddle.empty(size, "float64")
param = ppsci.utils.initializer.trunc_normal_(param, 0.0, xavier_stddev)
# TODO: Truncated normal and assign support float64
return lambda p_ten, _: p_ten.set_value(param)
return nn.initializer.Assign(param)

def neural_net_tanh(self, x, weights, biases, amplitudes):
num_layers = len(weights) + 1
Expand Down Expand Up @@ -479,8 +479,7 @@ def predict(self, x_star1, x_star2, x_star3):
print("Error u_total: %e" % (error_u_total))

############################# Plotting ###############################
if not os.path.exists("./target"):
os.mkdir("./target")
os.makedirs("./target", exist_ok=True)
fig, ax = plotting.newfig(1.0, 1.1)
plt.plot(range(1, max_iter + 1, 20), mse_hist1, "r-", linewidth=1, label="Sub-Net1")
plt.plot(
Expand Down
18 changes: 10 additions & 8 deletions ppsci/optimizer/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class Linear(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)
>>> lr = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)()
"""

def __init__(
Expand Down Expand Up @@ -218,7 +218,7 @@ class ExponentialDecay(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3)
>>> lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3)()
"""

def __init__(
Expand Down Expand Up @@ -280,7 +280,7 @@ class Cosine(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Cosine(10, 2, 1e-3)
>>> lr = ppsci.optimizer.lr_scheduler.Cosine(10, 2, 1e-3)()
"""

def __init__(
Expand Down Expand Up @@ -345,7 +345,7 @@ class Step(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Step(10, 1, 1e-3, 2, 0.95)
>>> lr = ppsci.optimizer.lr_scheduler.Step(10, 1, 1e-3, 2, 0.95)()
"""

def __init__(
Expand Down Expand Up @@ -407,7 +407,9 @@ class Piecewise(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Piecewise(10, 1, [2, 4], (1e-3, 1e-4))
>>> lr = ppsci.optimizer.lr_scheduler.Piecewise(
... 10, 1, [2, 4], (1e-3, 1e-4, 1e-5)
... )()
"""

def __init__(
Expand Down Expand Up @@ -467,7 +469,7 @@ class MultiStepDecay(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.MultiStepDecay(10, 1, 1e-3, (4, 5))
>>> lr = ppsci.optimizer.lr_scheduler.MultiStepDecay(10, 1, 1e-3, (4, 5))()
"""

def __init__(
Expand Down Expand Up @@ -601,7 +603,7 @@ class CosineWarmRestarts(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.CosineWarmRestarts(20, 1, 1e-3, 14, 2)
>>> lr = ppsci.optimizer.lr_scheduler.CosineWarmRestarts(20, 1, 1e-3, 14, 2)()
"""

def __init__(
Expand Down Expand Up @@ -676,7 +678,7 @@ class OneCycleLR(LRBase):

Examples:
>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.OneCycleLR(1e-3, 100)
>>> lr = ppsci.optimizer.lr_scheduler.OneCycleLR(100, 1, 1e-3)()
"""

def __init__(
Expand Down