Skip to content

Commit 30cac6a

Browse files
committed
Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right
1 parent 501993e commit 30cac6a

File tree

3 files changed

+87
-112
lines changed

3 files changed

+87
-112
lines changed

extensions-builtin/ScuNET/scripts/scunet_model.py

+10-38
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import sys
22

33
import PIL.Image
4-
import numpy as np
5-
import torch
64

75
import modules.upscaler
8-
from modules import devices, modelloader, script_callbacks, errors
9-
from modules.shared import opts
10-
from modules.upscaler_utils import tiled_upscale_2
6+
from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
117

128

139
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -40,46 +36,23 @@ def __init__(self, dirname):
4036
self.scalers = scalers
4137

4238
def do_upscale(self, img: PIL.Image.Image, selected_file):
43-
4439
devices.torch_gc()
45-
4640
try:
4741
model = self.load_model(selected_file)
4842
except Exception as e:
4943
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
5044
return img
5145

52-
device = devices.get_device_for('scunet')
53-
tile = opts.SCUNET_tile
54-
h, w = img.height, img.width
55-
np_img = np.array(img)
56-
np_img = np_img[:, :, ::-1] # RGB to BGR
57-
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
58-
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
59-
60-
if tile > h or tile > w:
61-
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
62-
_img[:, :, :h, :w] = torch_img # pad image
63-
torch_img = _img
64-
65-
with torch.no_grad():
66-
torch_output = tiled_upscale_2(
67-
torch_img,
68-
model,
69-
tile_size=opts.SCUNET_tile,
70-
tile_overlap=opts.SCUNET_tile_overlap,
71-
scale=1,
72-
device=devices.get_device_for('scunet'),
73-
desc="ScuNET tiles",
74-
).squeeze(0)
75-
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
76-
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
77-
del torch_img, torch_output
46+
img = upscaler_utils.upscale_2(
47+
img,
48+
model,
49+
tile_size=shared.opts.SCUNET_tile,
50+
tile_overlap=shared.opts.SCUNET_tile_overlap,
51+
scale=1, # ScuNET is a denoising model, not an upscaler
52+
desc='ScuNET',
53+
)
7854
devices.torch_gc()
79-
80-
output = np_output.transpose((1, 2, 0)) # CHW to HWC
81-
output = output[:, :, ::-1] # BGR to RGB
82-
return PIL.Image.fromarray((output * 255).astype(np.uint8))
55+
return img
8356

8457
def load_model(self, path: str):
8558
device = devices.get_device_for('scunet')
@@ -93,7 +66,6 @@ def load_model(self, path: str):
9366

9467
def on_ui_settings():
9568
import gradio as gr
96-
from modules import shared
9769

9870
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
9971
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))

extensions-builtin/SwinIR/scripts/swinir_model.py

+8-54
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import logging
22
import sys
33

4-
import numpy as np
5-
import torch
64
from PIL import Image
75

8-
from modules import modelloader, devices, script_callbacks, shared
9-
from modules.shared import opts
6+
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
107
from modules.upscaler import Upscaler, UpscalerData
11-
from modules.upscaler_utils import tiled_upscale_2
128

139
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
1410

@@ -36,9 +32,7 @@ def __init__(self, dirname):
3632
self.scalers = scalers
3733

3834
def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
39-
current_config = (model_file, opts.SWIN_tile)
40-
41-
device = self._get_device()
35+
current_config = (model_file, shared.opts.SWIN_tile)
4236

4337
if self._cached_model_config == current_config:
4438
model = self._cached_model
@@ -51,12 +45,13 @@ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
5145
self._cached_model = model
5246
self._cached_model_config = current_config
5347

54-
img = upscale(
48+
img = upscaler_utils.upscale_2(
5549
img,
5650
model,
57-
tile=opts.SWIN_tile,
58-
tile_overlap=opts.SWIN_tile_overlap,
59-
device=device,
51+
tile_size=shared.opts.SWIN_tile,
52+
tile_overlap=shared.opts.SWIN_tile_overlap,
53+
scale=4, # TODO: This was hard-coded before too...
54+
desc="SwinIR",
6055
)
6156
devices.torch_gc()
6257
return img
@@ -77,7 +72,7 @@ def load_model(self, path, scale=4):
7772
dtype=devices.dtype,
7873
expected_architecture="SwinIR",
7974
)
80-
if getattr(opts, 'SWIN_torch_compile', False):
75+
if getattr(shared.opts, 'SWIN_torch_compile', False):
8176
try:
8277
model_descriptor.model.compile()
8378
except Exception:
@@ -88,47 +83,6 @@ def _get_device(self):
8883
return devices.get_device_for('swinir')
8984

9085

91-
def upscale(
92-
img,
93-
model,
94-
*,
95-
tile: int,
96-
tile_overlap: int,
97-
window_size=8,
98-
scale=4,
99-
device,
100-
):
101-
102-
img = np.array(img)
103-
img = img[:, :, ::-1]
104-
img = np.moveaxis(img, 2, 0) / 255
105-
img = torch.from_numpy(img).float()
106-
img = img.unsqueeze(0).to(device, dtype=devices.dtype)
107-
with torch.no_grad(), devices.autocast():
108-
_, _, h_old, w_old = img.size()
109-
h_pad = (h_old // window_size + 1) * window_size - h_old
110-
w_pad = (w_old // window_size + 1) * window_size - w_old
111-
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
112-
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
113-
output = tiled_upscale_2(
114-
img,
115-
model,
116-
tile_size=tile,
117-
tile_overlap=tile_overlap,
118-
scale=scale,
119-
device=device,
120-
desc="SwinIR tiles",
121-
)
122-
output = output[..., : h_old * scale, : w_old * scale]
123-
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
124-
if output.ndim == 3:
125-
output = np.transpose(
126-
output[[2, 1, 0], :, :], (1, 2, 0)
127-
) # CHW-RGB to HCW-BGR
128-
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
129-
return Image.fromarray(output, "RGB")
130-
131-
13286
def on_ui_settings():
13387
import gradio as gr
13488

modules/upscaler_utils.py

+69-20
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,40 @@
1111
logger = logging.getLogger(__name__)
1212

1313

14-
def upscale_without_tiling(model, img: Image.Image):
15-
img = np.array(img)
16-
img = img[:, :, ::-1]
17-
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
18-
img = torch.from_numpy(img).float()
19-
14+
def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
15+
img = np.array(img.convert("RGB"))
16+
img = img[:, :, ::-1] # flip RGB to BGR
17+
img = np.transpose(img, (2, 0, 1)) # HWC to CHW
18+
img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
19+
return torch.from_numpy(img)
20+
21+
22+
def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
23+
if tensor.ndim == 4:
24+
# If we're given a tensor with a batch dimension, squeeze it out
25+
# (but only if it's a batch of size 1).
26+
if tensor.shape[0] != 1:
27+
raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
28+
tensor = tensor.squeeze(0)
29+
assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
30+
# TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
31+
arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
32+
arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
33+
arr = arr.astype(np.uint8)
34+
arr = arr[:, :, ::-1] # flip BGR to RGB
35+
return Image.fromarray(arr, "RGB")
36+
37+
38+
def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
39+
"""
40+
Upscale a given PIL image using the given model.
41+
"""
2042
param = torch_utils.get_param(model)
21-
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
2243

2344
with torch.no_grad():
24-
output = model(img)
25-
26-
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
27-
output = 255. * np.moveaxis(output, 0, 2)
28-
output = output.astype(np.uint8)
29-
output = output[:, :, ::-1]
30-
return Image.fromarray(output, 'RGB')
45+
tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
46+
tensor = tensor.to(device=param.device, dtype=param.dtype)
47+
return torch_bgr_to_pil_image(model(tensor))
3148

3249

3350
def upscale_with_model(
@@ -40,7 +57,7 @@ def upscale_with_model(
4057
) -> Image.Image:
4158
if tile_size <= 0:
4259
logger.debug("Upscaling %s without tiling", img)
43-
output = upscale_without_tiling(model, img)
60+
output = upscale_pil_patch(model, img)
4461
logger.debug("=> %s", output)
4562
return output
4663

@@ -52,7 +69,7 @@ def upscale_with_model(
5269
newrow = []
5370
for x, w, tile in row:
5471
logger.debug("Tile (%d, %d) %s...", x, y, tile)
55-
output = upscale_without_tiling(model, tile)
72+
output = upscale_pil_patch(model, tile)
5673
scale_factor = output.width // tile.width
5774
logger.debug("=> %s (scale factor %s)", output, scale_factor)
5875
newrow.append([x * scale_factor, w * scale_factor, output])
@@ -71,19 +88,22 @@ def upscale_with_model(
7188

7289

7390
def tiled_upscale_2(
74-
img,
91+
img: torch.Tensor,
7592
model,
7693
*,
7794
tile_size: int,
7895
tile_overlap: int,
7996
scale: int,
80-
device,
8197
desc="Tiled upscale",
8298
):
8399
# Alternative implementation of `upscale_with_model` originally used by
84100
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
85101
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
86102
# Pillow space without weighting.
103+
104+
# Grab the device the model is on, and use it.
105+
device = torch_utils.get_param(model).device
106+
87107
b, c, h, w = img.size()
88108
tile_size = min(tile_size, h, w)
89109

@@ -100,7 +120,8 @@ def tiled_upscale_2(
100120
h * scale,
101121
w * scale,
102122
device=device,
103-
).type_as(img)
123+
dtype=img.dtype,
124+
)
104125
weights = torch.zeros_like(result)
105126
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
106127
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
@@ -112,11 +133,13 @@ def tiled_upscale_2(
112133
if shared.state.interrupted or shared.state.skipped:
113134
break
114135

136+
# Only move this patch to the device if it's not already there.
115137
in_patch = img[
116138
...,
117139
h_idx : h_idx + tile_size,
118140
w_idx : w_idx + tile_size,
119-
]
141+
].to(device=device)
142+
120143
out_patch = model(in_patch)
121144

122145
result[
@@ -138,3 +161,29 @@ def tiled_upscale_2(
138161
output = result.div_(weights)
139162

140163
return output
164+
165+
166+
def upscale_2(
167+
img: Image.Image,
168+
model,
169+
*,
170+
tile_size: int,
171+
tile_overlap: int,
172+
scale: int,
173+
desc: str,
174+
):
175+
"""
176+
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
177+
"""
178+
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
179+
180+
with torch.no_grad():
181+
output = tiled_upscale_2(
182+
tensor,
183+
model,
184+
tile_size=tile_size,
185+
tile_overlap=tile_overlap,
186+
scale=scale,
187+
desc=desc,
188+
)
189+
return torch_bgr_to_pil_image(output)

0 commit comments

Comments
 (0)