11
11
logger = logging .getLogger (__name__ )
12
12
13
13
14
- def upscale_without_tiling (model , img : Image .Image ):
15
- img = np .array (img )
16
- img = img [:, :, ::- 1 ]
17
- img = np .ascontiguousarray (np .transpose (img , (2 , 0 , 1 ))) / 255
18
- img = torch .from_numpy (img ).float ()
19
-
14
def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a CHW, BGR-channel-order torch tensor in [0, 1]."""
    rgb = np.array(img.convert("RGB"))  # HWC uint8, RGB
    bgr = rgb[:, :, ::-1]  # reverse the channel axis: RGB -> BGR
    chw = np.transpose(bgr, (2, 0, 1))  # HWC -> CHW
    scaled = np.ascontiguousarray(chw) / 255  # rescale to [0, 1] (float64)
    return torch.from_numpy(scaled)
20
+
21
+
22
def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a CHW (or batch-of-one BCHW) BGR tensor with values in [0, 1]
    to an RGB PIL image.

    Raises:
        ValueError: if the tensor is neither CHW nor a BCHW batch of size 1.
    """
    if tensor.ndim == 4:
        # If we're given a tensor with a batch dimension, squeeze it out
        # (but only if it's a batch of size 1).
        if tensor.shape[0] != 1:
            raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
        tensor = tensor.squeeze(0)
    if tensor.ndim != 3:
        # Was an `assert`, which `python -O` strips away; raise explicitly,
        # consistent with the BCHW check above.
        raise ValueError(f"{tensor.shape} does not describe a CHW tensor")
    # `float()` copies, so the in-place clamp_ cannot clobber the caller's tensor.
    # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
    arr = tensor.float().cpu().clamp_(0, 1).numpy()  # clamp
    arr = 255.0 * np.moveaxis(arr, 0, 2)  # CHW to HWC, rescale
    arr = arr.astype(np.uint8)
    arr = arr[:, :, ::-1]  # flip BGR to RGB
    return Image.fromarray(arr, "RGB")
36
+
37
+
38
def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
    """
    Upscale a given PIL image using the given model.
    """
    # Read a parameter off the model so the input can match its device/dtype.
    param = torch_utils.get_param(model)

    with torch.no_grad():
        batch = (
            pil_image_to_torch_bgr(img)
            .unsqueeze(0)  # add batch dimension
            .to(device=param.device, dtype=param.dtype)
        )
        return torch_bgr_to_pil_image(model(batch))
31
48
32
49
33
50
def upscale_with_model (
@@ -40,7 +57,7 @@ def upscale_with_model(
40
57
) -> Image .Image :
41
58
if tile_size <= 0 :
42
59
logger .debug ("Upscaling %s without tiling" , img )
43
- output = upscale_without_tiling (model , img )
60
+ output = upscale_pil_patch (model , img )
44
61
logger .debug ("=> %s" , output )
45
62
return output
46
63
@@ -52,7 +69,7 @@ def upscale_with_model(
52
69
newrow = []
53
70
for x , w , tile in row :
54
71
logger .debug ("Tile (%d, %d) %s..." , x , y , tile )
55
- output = upscale_without_tiling (model , tile )
72
+ output = upscale_pil_patch (model , tile )
56
73
scale_factor = output .width // tile .width
57
74
logger .debug ("=> %s (scale factor %s)" , output , scale_factor )
58
75
newrow .append ([x * scale_factor , w * scale_factor , output ])
@@ -71,19 +88,22 @@ def upscale_with_model(
71
88
72
89
73
90
def tiled_upscale_2 (
74
- img ,
91
+ img : torch . Tensor ,
75
92
model ,
76
93
* ,
77
94
tile_size : int ,
78
95
tile_overlap : int ,
79
96
scale : int ,
80
- device ,
81
97
desc = "Tiled upscale" ,
82
98
):
83
99
# Alternative implementation of `upscale_with_model` originally used by
84
100
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
85
101
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
86
102
# Pillow space without weighting.
103
+
104
+ # Grab the device the model is on, and use it.
105
+ device = torch_utils .get_param (model ).device
106
+
87
107
b , c , h , w = img .size ()
88
108
tile_size = min (tile_size , h , w )
89
109
@@ -100,7 +120,8 @@ def tiled_upscale_2(
100
120
h * scale ,
101
121
w * scale ,
102
122
device = device ,
103
- ).type_as (img )
123
+ dtype = img .dtype ,
124
+ )
104
125
weights = torch .zeros_like (result )
105
126
logger .debug ("Upscaling %s to %s with tiles" , img .shape , result .shape )
106
127
with tqdm .tqdm (total = len (h_idx_list ) * len (w_idx_list ), desc = desc ) as pbar :
@@ -112,11 +133,13 @@ def tiled_upscale_2(
112
133
if shared .state .interrupted or shared .state .skipped :
113
134
break
114
135
136
+ # Only move this patch to the device if it's not already there.
115
137
in_patch = img [
116
138
...,
117
139
h_idx : h_idx + tile_size ,
118
140
w_idx : w_idx + tile_size ,
119
- ]
141
+ ].to (device = device )
142
+
120
143
out_patch = model (in_patch )
121
144
122
145
result [
@@ -138,3 +161,29 @@ def tiled_upscale_2(
138
161
output = result .div_ (weights )
139
162
140
163
return output
164
+
165
+
166
def upscale_2(
    img: Image.Image,
    model,
    *,
    tile_size: int,
    tile_overlap: int,
    scale: int,
    desc: str,
):
    """
    Convenience wrapper around `tiled_upscale_2` that handles PIL images.
    """
    # Batch of one, as float, in the BGR/CHW layout the tiled upscaler expects.
    batch = pil_image_to_torch_bgr(img).float().unsqueeze(0)

    with torch.no_grad():
        upscaled = tiled_upscale_2(
            batch,
            model,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
            scale=scale,
            desc=desc,
        )
    return torch_bgr_to_pil_image(upscaled)
0 commit comments