Skip to content

Commit 4bc70d8

Browse files
committed
Save infer model when saving checkpoint
1 parent c63f072 commit 4bc70d8

File tree

5 files changed

+40
-8
lines changed

5 files changed

+40
-8
lines changed

contrib/HumanSeg/models/humanseg.py

+11
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import yaml
2828
import shutil
2929
import paddleslim as slim
30+
import paddle
3031

3132
import utils
3233
import utils.logging as logging
@@ -37,6 +38,15 @@
3738
import transforms as T
3839

3940

41+
def save_infer_program(test_program, ckpt_dir):
42+
_test_program = test_program.clone()
43+
_test_program.desc.flush()
44+
_test_program.desc._set_version()
45+
paddle.fluid.core.save_op_compatible_info(_test_program.desc)
46+
with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f:
47+
f.write(_test_program.desc.serialize_to_string())
48+
49+
4050
def dict2str(dict_input):
4151
out = ''
4252
for k, v in dict_input.items():
@@ -244,6 +254,7 @@ def save_model(self, save_dir):
244254

245255
if self.status == 'Normal':
246256
fluid.save(self.train_prog, osp.join(save_dir, 'model'))
257+
save_infer_program(self.test_prog, save_dir)
247258
model_info['status'] = 'Normal'
248259
elif self.status == 'Quant':
249260
fluid.save(self.test_prog, osp.join(save_dir, 'model'))

contrib/RemoteSensing/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,3 @@
2121
from utils.utils import get_environ_info
2222

2323
env_info = get_environ_info()
24-
25-
log_level = 2

contrib/RemoteSensing/models/base.py

+11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@
3030
from utils.metrics import ConfusionMatrix
3131
import transforms.transforms as T
3232
import utils
33+
import paddle
34+
35+
36+
def save_infer_program(test_program, ckpt_dir):
37+
_test_program = test_program.clone()
38+
_test_program.desc.flush()
39+
_test_program.desc._set_version()
40+
paddle.fluid.core.save_op_compatible_info(_test_program.desc)
41+
with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f:
42+
f.write(_test_program.desc.serialize_to_string())
3343

3444

3545
def dict2str(dict_input):
@@ -238,6 +248,7 @@ def save_model(self, save_dir):
238248

239249
if self.status == 'Normal':
240250
fluid.save(self.train_prog, osp.join(save_dir, 'model'))
251+
save_infer_program(self.test_prog, save_dir)
241252

242253
model_info['status'] = self.status
243254
with open(

contrib/RemoteSensing/utils/logging.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import time
1717
import os
1818
import sys
19-
import __init__
2019

2120
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
2221

@@ -25,10 +24,9 @@ def log(level=2, message=""):
2524
current_time = time.time()
2625
time_array = time.localtime(current_time)
2726
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
28-
if __init__.log_level >= level:
29-
print("{} [{}]\t{}".format(current_time, levels[level],
30-
message).encode("utf-8").decode("latin1"))
31-
sys.stdout.flush()
27+
print("{} [{}]\t{}".format(current_time, levels[level],
28+
message).encode("utf-8").decode("latin1"))
29+
sys.stdout.flush()
3230

3331

3432
def debug(message=""):

pdseg/train.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import random
2828
import shutil
2929

30+
import paddle
3031
import numpy as np
3132
import paddle.fluid as fluid
3233
from paddle.fluid import profiler
@@ -158,6 +159,15 @@ def load_checkpoint(exe, program):
158159
return begin_epoch
159160

160161

162+
def save_infer_program(test_program, ckpt_dir):
163+
_test_program = test_program.clone()
164+
_test_program.desc.flush()
165+
_test_program.desc._set_version()
166+
paddle.fluid.core.save_op_compatible_info(_test_program.desc)
167+
with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f:
168+
f.write(_test_program.desc.serialize_to_string())
169+
170+
161171
def update_best_model(ckpt_dir):
162172
best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model')
163173
if os.path.exists(best_model_dir):
@@ -173,6 +183,7 @@ def print_info(*msg):
173183
def train(cfg):
174184
startup_prog = fluid.Program()
175185
train_prog = fluid.Program()
186+
test_prog = fluid.Program()
176187
if args.enable_ce:
177188
startup_prog.random_seed = 1000
178189
train_prog.random_seed = 1000
@@ -224,6 +235,7 @@ def data_generator():
224235

225236
data_loader, avg_loss, lr, pred, grts, masks = build_model(
226237
train_prog, startup_prog, phase=ModelPhase.TRAIN)
238+
build_model(test_prog, fluid.Program(), phase=ModelPhase.EVAL)
227239
data_loader.set_sample_generator(
228240
data_generator, batch_size=batch_size_per_dev, drop_last=drop_last)
229241

@@ -387,6 +399,7 @@ def data_generator():
387399
if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0
388400
or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0:
389401
ckpt_dir = save_checkpoint(train_prog, epoch)
402+
save_infer_program(test_prog, ckpt_dir)
390403

391404
if args.do_eval:
392405
print("Evaluation start")
@@ -419,7 +432,8 @@ def data_generator():
419432

420433
# save final model
421434
if cfg.TRAINER_ID == 0:
422-
save_checkpoint(train_prog, 'final')
435+
ckpt_dir = save_checkpoint(train_prog, 'final')
436+
save_infer_program(test_prog, ckpt_dir)
423437

424438

425439
def main(args):

0 commit comments

Comments
 (0)