Skip to content

Commit 6d2bdc8

Browse files
committed
add use xpu option, *test=kunlun
1 parent fec42fd commit 6d2bdc8

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

train.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def parse_args():
120120
help='The option of train profiler. If profiler_options is not None, the train ' \
121121
'profiler is enabled. Refer to the paddleseg/utils/train_profiler.py for details.'
122122
)
123+
parser.add_argument(
124+
'--device',
125+
dest='device',
126+
help='Device place, which can be GPU, XPU, CPU',
127+
default='gpu',
128+
type=str)
123129

124130
return parser.parse_args()
125131

@@ -137,8 +143,12 @@ def main(args):
137143
['-' * 48])
138144
logger.info(info)
139145

140-
place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
141-
'GPUs used'] else 'cpu'
146+
if args.device == 'gpu' and env_info['Paddle compiled with cuda'] and and env_info['GPUs used']:
147+
place = 'gpu'
148+
elif args.device == 'xpu' and paddle.is_compiled_with_xpu():
149+
place = 'xpu'
150+
else:
151+
place = 'cpu'
142152

143153
paddle.set_device(place)
144154
if not args.cfg:

0 commit comments

Comments
 (0)