From bb90c0a8d34c634af29384a9699138a8d33347ce Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Mon, 8 Nov 2021 14:35:57 +0800 Subject: [PATCH] fix the error of saving image in infer.py --- deploy/python/infer.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 13e4d140a2..aa31293280 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -213,11 +213,10 @@ def _init_gpu_config(self): self.pred_cfg.set_trt_dynamic_shape_info( min_input_shape, max_input_shape, opt_input_shape) - def run(self, imgs): - if not isinstance(imgs, (list, tuple)): - imgs = [imgs] + def run(self, imgs_path): + if not isinstance(imgs_path, (list, tuple)): + imgs_path = [imgs_path] - num = len(imgs) input_names = self.predictor.get_input_names() input_handle = self.predictor.get_input_handle(input_names[0]) output_names = self.predictor.get_output_names() @@ -228,12 +227,13 @@ def run(self, imgs): if not os.path.exists(args.save_dir): os.makedirs(args.save_dir) - for i in range(0, num, args.batch_size): + for i in range(0, len(imgs_path), args.batch_size): # warm up if i == 0 and args.benchmark: for j in range(5): data = np.array([ - self._preprocess(img) for img in imgs[0:args.batch_size] + self._preprocess(img) + for img in imgs_path[0:args.batch_size] ]) input_handle.reshape(data.shape) input_handle.copy_from_cpu(data) @@ -244,11 +244,12 @@ def run(self, imgs): # inference if args.benchmark: self.autolog.times.start() - data = np.array( - [self._preprocess(img) for img in imgs[i:i + args.batch_size]]) + data = np.array( + [self._preprocess(p) for p in imgs_path[i:i + args.batch_size]]) input_handle.reshape(data.shape) input_handle.copy_from_cpu(data) + if args.benchmark: self.autolog.times.stamp() @@ -256,13 +257,13 @@ def run(self, imgs): if args.benchmark: self.autolog.times.stamp() - results = output_handle.copy_to_cpu() + results = output_handle.copy_to_cpu() results = self._postprocess(results) + self._save_imgs(results, imgs_path[i:i + args.batch_size]) if args.benchmark: self.autolog.times.end(stamp=True) - self._save_imgs(results, imgs) logger.info("Finish") @@ -274,10 +275,10 @@ def _postprocess(self, results): results = np.argmax(results, axis=1) return results - def _save_imgs(self, results, imgs): + def _save_imgs(self, results, imgs_path): for i in range(results.shape[0]): result = get_pseudo_color_map(results[i]) - basename = os.path.basename(imgs[i]) + basename = os.path.basename(imgs_path[i]) basename, _ = os.path.splitext(basename) basename = f'{basename}.png' result.save(os.path.join(self.args.save_dir, basename))