Skip to content

Commit 88c1224

Browse files
committed
[DirectML] Fix samplers.
1 parent e2cbdab commit 88c1224

File tree

2 files changed

+121
-104
lines changed

2 files changed

+121
-104
lines changed

modules/dml/hijack/diffusers.py

+120-103
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
1+
from typing import Optional, Union, Tuple
12
import torch
23
import diffusers
34
import diffusers.utils.torch_utils
4-
from typing import Optional, Union, Tuple
55

66

77
def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, prev_timestep, model_output):
@@ -17,7 +17,7 @@ def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, pr
1717
# sample -> x_t
1818
# model_output -> e_θ(x_t, t)
1919
# prev_sample -> x_(t−δ)
20-
sample.__str__() # PNDM Sampling does not work without 'stringify'. (because it depends on PLMS)
20+
torch.dml.synchronize_tensor(sample) # DML synchronize
2121
alpha_prod_t = self.alphas_cumprod[timestep]
2222
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
2323
beta_prod_t = 1 - alpha_prod_t
@@ -53,51 +53,68 @@ def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, pr
5353

5454

5555
def UniPCMultistepScheduler_multistep_uni_p_bh_update(
56-
self: diffusers.UniPCMultistepScheduler,
56+
self,
5757
model_output: torch.FloatTensor,
58-
prev_timestep: int,
59-
sample: torch.FloatTensor,
60-
order: int,
58+
*args,
59+
sample: torch.FloatTensor = None,
60+
order: int = None,
61+
**kwargs,
6162
) -> torch.FloatTensor:
6263
"""
6364
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
6465
6566
Args:
6667
model_output (`torch.FloatTensor`):
67-
direct outputs from learned diffusion model at the current timestep.
68-
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
68+
The direct output from the learned diffusion model at the current timestep.
69+
prev_timestep (`int`):
70+
The previous discrete timestep in the diffusion chain.
6971
sample (`torch.FloatTensor`):
70-
current instance of sample being created by diffusion process.
71-
order (`int`): the order of UniP at this step, also the p in UniPC-p.
72+
A current instance of a sample created by the diffusion process.
73+
order (`int`):
74+
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
7275
7376
Returns:
74-
`torch.FloatTensor`: the sample tensor at the previous timestep.
77+
`torch.FloatTensor`:
78+
The sample tensor at the previous timestep.
7579
"""
76-
timestep_list = self.timestep_list
80+
if sample is None:
81+
if len(args) > 1:
82+
sample = args[1]
83+
else:
84+
raise ValueError(" missing `sample` as a required keyward argument")
85+
if order is None:
86+
if len(args) > 2:
87+
order = args[2]
88+
else:
89+
raise ValueError(" missing `order` as a required keyward argument")
7790
model_output_list = self.model_outputs
7891

79-
s0, t = self.timestep_list[-1], prev_timestep
92+
s0 = self.timestep_list[-1]
8093
m0 = model_output_list[-1]
8194
x = sample
8295

8396
if self.solver_p:
8497
x_t = self.solver_p.step(model_output, s0, x).prev_sample
8598
return x_t
8699

87-
sample.__str__() # UniPC Sampling does not work without 'stringify'.
88-
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
89-
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
90-
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
100+
torch.dml.synchronize_tensor(sample) # DML synchronize
101+
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
102+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
103+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
104+
105+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
106+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
91107

92108
h = lambda_t - lambda_s0
93109
device = sample.device
94110

95111
rks = []
96112
D1s = []
97113
for i in range(1, order):
98-
si = timestep_list[-(i + 1)]
114+
si = self.step_index - i
99115
mi = model_output_list[-(i + 1)]
100-
lambda_si = self.lambda_t[si]
116+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
117+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
101118
rk = (lambda_si - lambda_s0) / h
102119
rks.append(rk)
103120
D1s.append((mi - m0) / rk)
@@ -143,14 +160,14 @@ def UniPCMultistepScheduler_multistep_uni_p_bh_update(
143160
if self.predict_x0:
144161
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
145162
if D1s is not None:
146-
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
163+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
147164
else:
148165
pred_res = 0
149166
x_t = x_t_ - alpha_t * B_h * pred_res
150167
else:
151168
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
152169
if D1s is not None:
153-
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
170+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
154171
else:
155172
pred_res = 0
156173
x_t = x_t_ - sigma_t * B_h * pred_res
@@ -170,91 +187,91 @@ def LCMScheduler_step(
170187
generator: Optional[torch.Generator] = None,
171188
return_dict: bool = True,
172189
) -> Union[diffusers.schedulers.scheduling_lcm.LCMSchedulerOutput, Tuple]:
173-
"""
174-
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
175-
process from the learned model outputs (most often the predicted noise).
176-
177-
Args:
178-
model_output (`torch.FloatTensor`):
179-
The direct output from learned diffusion model.
180-
timestep (`float`):
181-
The current discrete timestep in the diffusion chain.
182-
sample (`torch.FloatTensor`):
183-
A current instance of a sample created by the diffusion process.
184-
generator (`torch.Generator`, *optional*):
185-
A random number generator.
186-
return_dict (`bool`, *optional*, defaults to `True`):
187-
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
188-
Returns:
189-
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
190-
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
191-
tuple is returned where the first element is the sample tensor.
192-
"""
193-
if self.num_inference_steps is None:
194-
raise ValueError(
195-
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
196-
)
197-
198-
if self.step_index is None:
199-
self._init_step_index(timestep)
200-
201-
# 1. get previous step value
202-
prev_step_index = self.step_index + 1
203-
if prev_step_index < len(self.timesteps):
204-
prev_timestep = self.timesteps[prev_step_index]
205-
else:
206-
prev_timestep = timestep
207-
208-
# 2. compute alphas, betas
209-
sample.__str__()
210-
alpha_prod_t = self.alphas_cumprod[timestep]
211-
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
212-
213-
beta_prod_t = 1 - alpha_prod_t
214-
beta_prod_t_prev = 1 - alpha_prod_t_prev
215-
216-
# 3. Get scalings for boundary conditions
217-
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
218-
219-
# 4. Compute the predicted original sample x_0 based on the model parameterization
220-
if self.config.prediction_type == "epsilon": # noise-prediction
221-
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
222-
elif self.config.prediction_type == "sample": # x-prediction
223-
predicted_original_sample = model_output
224-
elif self.config.prediction_type == "v_prediction": # v-prediction
225-
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
226-
else:
227-
raise ValueError(
228-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
229-
" `v_prediction` for `LCMScheduler`."
230-
)
231-
232-
# 5. Clip or threshold "predicted x_0"
233-
if self.config.thresholding:
234-
predicted_original_sample = self._threshold_sample(predicted_original_sample)
235-
elif self.config.clip_sample:
236-
predicted_original_sample = predicted_original_sample.clamp(
237-
-self.config.clip_sample_range, self.config.clip_sample_range
238-
)
239-
240-
# 6. Denoise model output using boundary conditions
241-
denoised = c_out * predicted_original_sample + c_skip * sample
242-
243-
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
244-
# Noise is not used for one-step sampling.
245-
if len(self.timesteps) > 1:
246-
noise = diffusers.utils.torch_utils.randn_tensor(model_output.shape, generator=generator, device=model_output.device)
247-
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
248-
else:
249-
prev_sample = denoised
190+
"""
191+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
192+
process from the learned model outputs (most often the predicted noise).
193+
194+
Args:
195+
model_output (`torch.FloatTensor`):
196+
The direct output from learned diffusion model.
197+
timestep (`float`):
198+
The current discrete timestep in the diffusion chain.
199+
sample (`torch.FloatTensor`):
200+
A current instance of a sample created by the diffusion process.
201+
generator (`torch.Generator`, *optional*):
202+
A random number generator.
203+
return_dict (`bool`, *optional*, defaults to `True`):
204+
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
205+
Returns:
206+
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
207+
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
208+
tuple is returned where the first element is the sample tensor.
209+
"""
210+
if self.num_inference_steps is None:
211+
raise ValueError(
212+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
213+
)
214+
215+
if self.step_index is None:
216+
self._init_step_index(timestep)
217+
218+
# 1. get previous step value
219+
prev_step_index = self.step_index + 1
220+
if prev_step_index < len(self.timesteps):
221+
prev_timestep = self.timesteps[prev_step_index]
222+
else:
223+
prev_timestep = timestep
224+
225+
# 2. compute alphas, betas
226+
torch.dml.synchronize_tensor(sample) # DML synchronize
227+
alpha_prod_t = self.alphas_cumprod[timestep]
228+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
229+
230+
beta_prod_t = 1 - alpha_prod_t
231+
beta_prod_t_prev = 1 - alpha_prod_t_prev
232+
233+
# 3. Get scalings for boundary conditions
234+
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
235+
236+
# 4. Compute the predicted original sample x_0 based on the model parameterization
237+
if self.config.prediction_type == "epsilon": # noise-prediction
238+
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
239+
elif self.config.prediction_type == "sample": # x-prediction
240+
predicted_original_sample = model_output
241+
elif self.config.prediction_type == "v_prediction": # v-prediction
242+
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
243+
else:
244+
raise ValueError(
245+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
246+
" `v_prediction` for `LCMScheduler`."
247+
)
248+
249+
# 5. Clip or threshold "predicted x_0"
250+
if self.config.thresholding:
251+
predicted_original_sample = self._threshold_sample(predicted_original_sample)
252+
elif self.config.clip_sample:
253+
predicted_original_sample = predicted_original_sample.clamp(
254+
-self.config.clip_sample_range, self.config.clip_sample_range
255+
)
256+
257+
# 6. Denoise model output using boundary conditions
258+
denoised = c_out * predicted_original_sample + c_skip * sample
259+
260+
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
261+
# Noise is not used for one-step sampling.
262+
if len(self.timesteps) > 1:
263+
noise = diffusers.utils.torch_utils.randn_tensor(model_output.shape, generator=generator, device=model_output.device)
264+
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
265+
else:
266+
prev_sample = denoised
250267

251-
# upon completion increase step index by one
252-
self._step_index += 1
268+
# upon completion increase step index by one
269+
self._step_index += 1
253270

254-
if not return_dict:
255-
return (prev_sample, denoised)
271+
if not return_dict:
272+
return (prev_sample, denoised)
256273

257-
return diffusers.schedulers.scheduling_lcm.LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
274+
return diffusers.schedulers.scheduling_lcm.LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
258275

259276

260277
diffusers.LCMScheduler.step = LCMScheduler_step

modules/dml/hijack/stablediffusion.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F
5151
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
5252
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
5353
# select parameters corresponding to the currently considered timestep
54-
alphas[index].__str__() # synchronize DML device
54+
torch.dml.synchronize_tensor(alphas[index]) # DML synchronize
5555
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
5656
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
5757
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)

0 commit comments

Comments
 (0)