1
+ from typing import Optional , Union , Tuple
1
2
import torch
2
3
import diffusers
3
4
import diffusers .utils .torch_utils
4
- from typing import Optional , Union , Tuple
5
5
6
6
7
7
def PNDMScheduler__get_prev_sample (self , sample : torch .FloatTensor , timestep , prev_timestep , model_output ):
@@ -17,7 +17,7 @@ def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, pr
17
17
# sample -> x_t
18
18
# model_output -> e_θ(x_t, t)
19
19
# prev_sample -> x_(t−δ)
20
- sample . __str__ ( ) # PNDM Sampling does not work without 'stringify'. (because it depends on PLMS)
20
+ torch . dml . synchronize_tensor ( sample ) # DML synchronize
21
21
alpha_prod_t = self .alphas_cumprod [timestep ]
22
22
alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self .final_alpha_cumprod
23
23
beta_prod_t = 1 - alpha_prod_t
@@ -53,51 +53,68 @@ def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, pr
53
53
54
54
55
55
def UniPCMultistepScheduler_multistep_uni_p_bh_update (
56
- self : diffusers . UniPCMultistepScheduler ,
56
+ self ,
57
57
model_output : torch .FloatTensor ,
58
- prev_timestep : int ,
59
- sample : torch .FloatTensor ,
60
- order : int ,
58
+ * args ,
59
+ sample : torch .FloatTensor = None ,
60
+ order : int = None ,
61
+ ** kwargs ,
61
62
) -> torch .FloatTensor :
62
63
"""
63
64
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
64
65
65
66
Args:
66
67
model_output (`torch.FloatTensor`):
67
- direct outputs from learned diffusion model at the current timestep.
68
- prev_timestep (`int`): previous discrete timestep in the diffusion chain.
68
+ The direct output from the learned diffusion model at the current timestep.
69
+ prev_timestep (`int`):
70
+ The previous discrete timestep in the diffusion chain.
69
71
sample (`torch.FloatTensor`):
70
- current instance of sample being created by diffusion process.
71
- order (`int`): the order of UniP at this step, also the p in UniPC-p.
72
+ A current instance of a sample created by the diffusion process.
73
+ order (`int`):
74
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
72
75
73
76
Returns:
74
- `torch.FloatTensor`: the sample tensor at the previous timestep.
77
+ `torch.FloatTensor`:
78
+ The sample tensor at the previous timestep.
75
79
"""
76
- timestep_list = self .timestep_list
80
+ if sample is None :
81
+ if len (args ) > 1 :
82
+ sample = args [1 ]
83
+ else :
84
+ raise ValueError (" missing `sample` as a required keyward argument" )
85
+ if order is None :
86
+ if len (args ) > 2 :
87
+ order = args [2 ]
88
+ else :
89
+ raise ValueError (" missing `order` as a required keyward argument" )
77
90
model_output_list = self .model_outputs
78
91
79
- s0 , t = self .timestep_list [- 1 ], prev_timestep
92
+ s0 = self .timestep_list [- 1 ]
80
93
m0 = model_output_list [- 1 ]
81
94
x = sample
82
95
83
96
if self .solver_p :
84
97
x_t = self .solver_p .step (model_output , s0 , x ).prev_sample
85
98
return x_t
86
99
87
- sample .__str__ () # UniPC Sampling does not work without 'stringify'.
88
- lambda_t , lambda_s0 = self .lambda_t [t ], self .lambda_t [s0 ]
89
- alpha_t , alpha_s0 = self .alpha_t [t ], self .alpha_t [s0 ]
90
- sigma_t , sigma_s0 = self .sigma_t [t ], self .sigma_t [s0 ]
100
+ torch .dml .synchronize_tensor (sample ) # DML synchronize
101
+ sigma_t , sigma_s0 = self .sigmas [self .step_index + 1 ], self .sigmas [self .step_index ]
102
+ alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t )
103
+ alpha_s0 , sigma_s0 = self ._sigma_to_alpha_sigma_t (sigma_s0 )
104
+
105
+ lambda_t = torch .log (alpha_t ) - torch .log (sigma_t )
106
+ lambda_s0 = torch .log (alpha_s0 ) - torch .log (sigma_s0 )
91
107
92
108
h = lambda_t - lambda_s0
93
109
device = sample .device
94
110
95
111
rks = []
96
112
D1s = []
97
113
for i in range (1 , order ):
98
- si = timestep_list [ - ( i + 1 )]
114
+ si = self . step_index - i
99
115
mi = model_output_list [- (i + 1 )]
100
- lambda_si = self .lambda_t [si ]
116
+ alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (self .sigmas [si ])
117
+ lambda_si = torch .log (alpha_si ) - torch .log (sigma_si )
101
118
rk = (lambda_si - lambda_s0 ) / h
102
119
rks .append (rk )
103
120
D1s .append ((mi - m0 ) / rk )
@@ -143,14 +160,14 @@ def UniPCMultistepScheduler_multistep_uni_p_bh_update(
143
160
if self .predict_x0 :
144
161
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
145
162
if D1s is not None :
146
- pred_res = torch .einsum ("k,bkchw->bchw " , rhos_p , D1s )
163
+ pred_res = torch .einsum ("k,bkc...->bc... " , rhos_p , D1s )
147
164
else :
148
165
pred_res = 0
149
166
x_t = x_t_ - alpha_t * B_h * pred_res
150
167
else :
151
168
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
152
169
if D1s is not None :
153
- pred_res = torch .einsum ("k,bkchw->bchw " , rhos_p , D1s )
170
+ pred_res = torch .einsum ("k,bkc...->bc... " , rhos_p , D1s )
154
171
else :
155
172
pred_res = 0
156
173
x_t = x_t_ - sigma_t * B_h * pred_res
@@ -170,91 +187,91 @@ def LCMScheduler_step(
170
187
generator : Optional [torch .Generator ] = None ,
171
188
return_dict : bool = True ,
172
189
) -> Union [diffusers .schedulers .scheduling_lcm .LCMSchedulerOutput , Tuple ]:
173
- """
174
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
175
- process from the learned model outputs (most often the predicted noise).
176
-
177
- Args:
178
- model_output (`torch.FloatTensor`):
179
- The direct output from learned diffusion model.
180
- timestep (`float`):
181
- The current discrete timestep in the diffusion chain.
182
- sample (`torch.FloatTensor`):
183
- A current instance of a sample created by the diffusion process.
184
- generator (`torch.Generator`, *optional*):
185
- A random number generator.
186
- return_dict (`bool`, *optional*, defaults to `True`):
187
- Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
188
- Returns:
189
- [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
190
- If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
191
- tuple is returned where the first element is the sample tensor.
192
- """
193
- if self .num_inference_steps is None :
194
- raise ValueError (
195
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
196
- )
197
-
198
- if self .step_index is None :
199
- self ._init_step_index (timestep )
200
-
201
- # 1. get previous step value
202
- prev_step_index = self .step_index + 1
203
- if prev_step_index < len (self .timesteps ):
204
- prev_timestep = self .timesteps [prev_step_index ]
205
- else :
206
- prev_timestep = timestep
207
-
208
- # 2. compute alphas, betas
209
- sample . __str__ ()
210
- alpha_prod_t = self .alphas_cumprod [timestep ]
211
- alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self .final_alpha_cumprod
212
-
213
- beta_prod_t = 1 - alpha_prod_t
214
- beta_prod_t_prev = 1 - alpha_prod_t_prev
215
-
216
- # 3. Get scalings for boundary conditions
217
- c_skip , c_out = self .get_scalings_for_boundary_condition_discrete (timestep )
218
-
219
- # 4. Compute the predicted original sample x_0 based on the model parameterization
220
- if self .config .prediction_type == "epsilon" : # noise-prediction
221
- predicted_original_sample = (sample - beta_prod_t .sqrt () * model_output ) / alpha_prod_t .sqrt ()
222
- elif self .config .prediction_type == "sample" : # x-prediction
223
- predicted_original_sample = model_output
224
- elif self .config .prediction_type == "v_prediction" : # v-prediction
225
- predicted_original_sample = alpha_prod_t .sqrt () * sample - beta_prod_t .sqrt () * model_output
226
- else :
227
- raise ValueError (
228
- f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample` or"
229
- " `v_prediction` for `LCMScheduler`."
230
- )
231
-
232
- # 5. Clip or threshold "predicted x_0"
233
- if self .config .thresholding :
234
- predicted_original_sample = self ._threshold_sample (predicted_original_sample )
235
- elif self .config .clip_sample :
236
- predicted_original_sample = predicted_original_sample .clamp (
237
- - self .config .clip_sample_range , self .config .clip_sample_range
238
- )
239
-
240
- # 6. Denoise model output using boundary conditions
241
- denoised = c_out * predicted_original_sample + c_skip * sample
242
-
243
- # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
244
- # Noise is not used for one-step sampling.
245
- if len (self .timesteps ) > 1 :
246
- noise = diffusers .utils .torch_utils .randn_tensor (model_output .shape , generator = generator , device = model_output .device )
247
- prev_sample = alpha_prod_t_prev .sqrt () * denoised + beta_prod_t_prev .sqrt () * noise
248
- else :
249
- prev_sample = denoised
190
+ """
191
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
192
+ process from the learned model outputs (most often the predicted noise).
193
+
194
+ Args:
195
+ model_output (`torch.FloatTensor`):
196
+ The direct output from learned diffusion model.
197
+ timestep (`float`):
198
+ The current discrete timestep in the diffusion chain.
199
+ sample (`torch.FloatTensor`):
200
+ A current instance of a sample created by the diffusion process.
201
+ generator (`torch.Generator`, *optional*):
202
+ A random number generator.
203
+ return_dict (`bool`, *optional*, defaults to `True`):
204
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
205
+ Returns:
206
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
207
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
208
+ tuple is returned where the first element is the sample tensor.
209
+ """
210
+ if self .num_inference_steps is None :
211
+ raise ValueError (
212
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
213
+ )
214
+
215
+ if self .step_index is None :
216
+ self ._init_step_index (timestep )
217
+
218
+ # 1. get previous step value
219
+ prev_step_index = self .step_index + 1
220
+ if prev_step_index < len (self .timesteps ):
221
+ prev_timestep = self .timesteps [prev_step_index ]
222
+ else :
223
+ prev_timestep = timestep
224
+
225
+ # 2. compute alphas, betas
226
+ torch . dml . synchronize_tensor ( sample ) # DML synchronize
227
+ alpha_prod_t = self .alphas_cumprod [timestep ]
228
+ alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self .final_alpha_cumprod
229
+
230
+ beta_prod_t = 1 - alpha_prod_t
231
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
232
+
233
+ # 3. Get scalings for boundary conditions
234
+ c_skip , c_out = self .get_scalings_for_boundary_condition_discrete (timestep )
235
+
236
+ # 4. Compute the predicted original sample x_0 based on the model parameterization
237
+ if self .config .prediction_type == "epsilon" : # noise-prediction
238
+ predicted_original_sample = (sample - beta_prod_t .sqrt () * model_output ) / alpha_prod_t .sqrt ()
239
+ elif self .config .prediction_type == "sample" : # x-prediction
240
+ predicted_original_sample = model_output
241
+ elif self .config .prediction_type == "v_prediction" : # v-prediction
242
+ predicted_original_sample = alpha_prod_t .sqrt () * sample - beta_prod_t .sqrt () * model_output
243
+ else :
244
+ raise ValueError (
245
+ f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample` or"
246
+ " `v_prediction` for `LCMScheduler`."
247
+ )
248
+
249
+ # 5. Clip or threshold "predicted x_0"
250
+ if self .config .thresholding :
251
+ predicted_original_sample = self ._threshold_sample (predicted_original_sample )
252
+ elif self .config .clip_sample :
253
+ predicted_original_sample = predicted_original_sample .clamp (
254
+ - self .config .clip_sample_range , self .config .clip_sample_range
255
+ )
256
+
257
+ # 6. Denoise model output using boundary conditions
258
+ denoised = c_out * predicted_original_sample + c_skip * sample
259
+
260
+ # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
261
+ # Noise is not used for one-step sampling.
262
+ if len (self .timesteps ) > 1 :
263
+ noise = diffusers .utils .torch_utils .randn_tensor (model_output .shape , generator = generator , device = model_output .device )
264
+ prev_sample = alpha_prod_t_prev .sqrt () * denoised + beta_prod_t_prev .sqrt () * noise
265
+ else :
266
+ prev_sample = denoised
250
267
251
- # upon completion increase step index by one
252
- self ._step_index += 1
268
+ # upon completion increase step index by one
269
+ self ._step_index += 1
253
270
254
- if not return_dict :
255
- return (prev_sample , denoised )
271
+ if not return_dict :
272
+ return (prev_sample , denoised )
256
273
257
- return diffusers .schedulers .scheduling_lcm .LCMSchedulerOutput (prev_sample = prev_sample , denoised = denoised )
274
+ return diffusers .schedulers .scheduling_lcm .LCMSchedulerOutput (prev_sample = prev_sample , denoised = denoised )
258
275
259
276
260
277
diffusers .LCMScheduler .step = LCMScheduler_step
0 commit comments