@@ -49,7 +49,7 @@ impl DPMSolverSinglestepScheduler {
4949 let lambda_t = alpha_t. log ( ) - sigma_t. log ( ) ;
5050
5151 let step = ( config. train_timesteps - 1 ) as f64 / inference_steps as f64 ;
52- // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep .py#L199-L204
52+ // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep .py#L172-L173
5353 let timesteps: Vec < usize > = ( 0 ..inference_steps + 1 )
5454 . map ( |i| ( i as f64 * step) . round ( ) as usize )
5555 // discards the 0.0 element
@@ -62,13 +62,15 @@ impl DPMSolverSinglestepScheduler {
6262 model_outputs. push ( Tensor :: new ( ) ) ;
6363 }
6464
65+ let order_list = get_order_list ( inference_steps, config. solver_order , true ) ;
66+
6567 Self {
6668 alphas_cumprod : Vec :: < f64 > :: from ( alphas_cumprod) ,
6769 alpha_t : Vec :: < f64 > :: from ( alpha_t) ,
6870 sigma_t : Vec :: < f64 > :: from ( sigma_t) ,
6971 lambda_t : Vec :: < f64 > :: from ( lambda_t) ,
7072 init_noise_sigma : 1. ,
71- order_list : get_order_list ( inference_steps , config . solver_order , false ) ,
73+ order_list,
7274 model_outputs,
7375 timesteps,
7476 config,
@@ -292,6 +294,12 @@ impl DPMSolverSinglestepScheduler {
292294 self . timesteps . as_slice ( )
293295 }
294296
297+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
298+ /// depending on the current timestep.
299+ pub fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
300+ sample
301+ }
302+
295303 /// Step function propagating the sample with the singlestep DPM-Solver
296304 ///
297305 /// # Arguments
@@ -311,7 +319,7 @@ impl DPMSolverSinglestepScheduler {
311319 self . model_outputs [ i] = self . model_outputs [ i + 1 ] . shallow_clone ( ) ;
312320 }
313321 let m = self . model_outputs . len ( ) ;
314- self . model_outputs [ m - 1 ] = model_output;
322+ self . model_outputs [ m - 1 ] = model_output. shallow_clone ( ) ;
315323
316324 let order = self . order_list [ step_index] ;
317325
@@ -320,7 +328,7 @@ impl DPMSolverSinglestepScheduler {
320328 self . sample = Some ( sample. shallow_clone ( ) ) ;
321329 } ;
322330
323- let prev_sample = match order {
331+ match order {
324332 1 => self . dpm_solver_first_order_update (
325333 & self . model_outputs [ self . model_outputs . len ( ) - 1 ] ,
326334 timestep,
@@ -331,7 +339,7 @@ impl DPMSolverSinglestepScheduler {
331339 & self . model_outputs ,
332340 [ self . timesteps [ step_index - 1 ] , self . timesteps [ step_index] ] ,
333341 prev_timestep,
334- self . sample . as_ref ( ) . unwrap ( ) ,
342+ & self . sample . as_ref ( ) . unwrap ( ) ,
335343 ) ,
336344 3 => self . singlestep_dpm_solver_third_order_update (
337345 & self . model_outputs ,
@@ -341,14 +349,12 @@ impl DPMSolverSinglestepScheduler {
341349 self . timesteps [ step_index] ,
342350 ] ,
343351 prev_timestep,
344- self . sample . as_ref ( ) . unwrap ( ) ,
352+ & self . sample . as_ref ( ) . unwrap ( ) ,
345353 ) ,
346354 _ => {
347355 panic ! ( "invalid order" ) ;
348356 }
349- } ;
350-
351- prev_sample
357+ }
352358 }
353359
354360 pub fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
0 commit comments