
Commit 58c1954

Add FP32 fallback support on sd_vae_approx
This tries to execute `interpolate` with FP32 if it fails. The background is that on some environments, such as Apple Silicon (Mx chip) macOS devices, we get the following error:

```
"torch/nn/functional.py", line 3931, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half'
```

In this case, `--no-half` does not solve the problem, so this commit adds an FP32 fallback execution path instead.

Note that submodules may require the same modification. The following is an example modification to another submodule, `repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py`:

```python
class Upsample(nn.Module):
    # ..snip..
    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            try:
                x = F.interpolate(x, scale_factor=2, mode="nearest")
            except RuntimeError:
                # FP32 fallback: interpolate in float32, then cast back to the input dtype
                x = F.interpolate(x.to(th.float32), scale_factor=2, mode="nearest").to(x.dtype)
        if self.use_conv:
            x = self.conv(x)
        return x
    # ..snip..
```

You can see the FP32 fallback execution is the same as in sd_vae_approx.py.
1 parent 5f36f6a commit 58c1954
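
The same fallback can be factored into a small reusable wrapper. The sketch below is illustrative only and not part of this commit; the `interpolate_with_fp32_fallback` name is hypothetical:

```python
import torch
import torch.nn.functional as F


def interpolate_with_fp32_fallback(x, *args, **kwargs):
    """Hypothetical helper showing the commit's pattern: try the tensor's
    native dtype first, and retry in FP32 only for the missing-Half-kernel case."""
    try:
        return F.interpolate(x, *args, **kwargs)
    except RuntimeError as e:
        if "not implemented for" in str(e) and "Half" in str(e):
            # Upsample in float32, then cast the result back to the original dtype.
            return F.interpolate(x.to(torch.float32), *args, **kwargs).to(x.dtype)
        raise
```

Matching on the error message is admittedly fragile, but PyTorch raises a plain `RuntimeError` here, so inspecting the message text is the only way to distinguish the missing-Half-kernel case from other failures.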

File tree

1 file changed: +7 −1 lines changed

modules/sd_vae_approx.py

+7 −1
```diff
@@ -21,7 +21,13 @@ def __init__(self):

     def forward(self, x):
         extra = 11
-        x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        try:
+            x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        except RuntimeError as e:
+            if "not implemented for" in str(e) and "Half" in str(e):
+                x = nn.functional.interpolate(x.to(torch.float32), (x.shape[2] * 2, x.shape[3] * 2)).to(x.dtype)
+            else:
+                print(f"An unexpected RuntimeError occurred: {str(e)}")
         x = nn.functional.pad(x, (extra, extra, extra, extra))

         for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
```
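
For reference, a quick way to check whether a given device needs this fallback is to attempt a half-precision `interpolate` directly. This is a minimal sketch, not part of the commit; the tensor shape is arbitrary:

```python
import torch
import torch.nn.functional as F

# A small Half tensor standing in for the latent passed to the VAE approximation.
x = torch.randn(1, 4, 8, 8, dtype=torch.float16)

try:
    F.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
    print("Half interpolate works here; the fallback path is never taken.")
except RuntimeError as e:
    # On affected devices this raises the "not implemented for 'Half'" error,
    # and the patched forward() would retry in float32 instead.
    print(f"Fallback needed: {e}")
```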
