11use std:: iter:: repeat;
22
3- use super :: { betas_for_alpha_bar, BetaSchedule , PredictionType } ;
3+ use super :: {
4+ betas_for_alpha_bar,
5+ dpmsolver:: { DPMSolverAlgorithmType , DPMSolverSchedulerConfig , DPMSolverType } ,
6+ BetaSchedule , PredictionType ,
7+ } ;
48use tch:: { kind, Kind , Tensor } ;
59
6- /// The algorithm type for the solver.
7- ///
8- #[ derive( Default , Debug , Clone , PartialEq , Eq ) ]
9- pub enum DPMSolverAlgorithmType {
10- /// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>.
11- #[ default]
12- DPMSolverPlusPlus ,
13- /// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>.
14- DPMSolver ,
15- }
16-
17- /// The solver type for the second-order solver.
18- /// The solver type slightly affects the sample quality, especially for
19- /// small number of steps.
20- #[ derive( Default , Debug , Clone , PartialEq , Eq ) ]
21- pub enum DPMSolverType {
22- #[ default]
23- Midpoint ,
24- Heun ,
25- }
26-
27- #[ derive( Debug , Clone ) ]
28- pub struct DPMSolverSinglestepSchedulerConfig {
29- /// The value of beta at the beginning of training.
30- pub beta_start : f64 ,
31- /// The value of beta at the end of training.
32- pub beta_end : f64 ,
33- /// How beta evolved during training.
34- pub beta_schedule : BetaSchedule ,
35- /// number of diffusion steps used to train the model.
36- pub train_timesteps : usize ,
37- /// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
38- /// sampling, and `solver_order=3` for unconditional sampling.
39- pub solver_order : usize ,
40- /// prediction type of the scheduler function
41- pub prediction_type : PredictionType ,
42- /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and
43- /// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`.
44- pub sample_max_value : f32 ,
45- /// The algorithm type for the solver
46- pub algorithm_type : DPMSolverAlgorithmType ,
47- /// The solver type for the second-order solver.
48- pub solver_type : DPMSolverType ,
49- /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
50- /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10.
51- pub lower_order_final : bool ,
52- }
53-
54- impl Default for DPMSolverSinglestepSchedulerConfig {
55- fn default ( ) -> Self {
56- Self {
57- beta_start : 0.0001 ,
58- beta_end : 0.02 ,
59- train_timesteps : 1000 ,
60- beta_schedule : BetaSchedule :: Linear ,
61- solver_order : 2 ,
62- prediction_type : PredictionType :: Epsilon ,
63- sample_max_value : 1.0 ,
64- algorithm_type : DPMSolverAlgorithmType :: DPMSolverPlusPlus ,
65- solver_type : DPMSolverType :: Midpoint ,
66- lower_order_final : true ,
67- }
68- }
69- }
70-
7110pub struct DPMSolverSinglestepScheduler {
7211 alphas_cumprod : Vec < f64 > ,
7312 alpha_t : Vec < f64 > ,
7413 sigma_t : Vec < f64 > ,
7514 lambda_t : Vec < f64 > ,
7615 init_noise_sigma : f64 ,
7716 order_list : Vec < usize > ,
17+ /// Direct outputs from learned diffusion model at current and latter timesteps
7818 model_outputs : Vec < Tensor > ,
19+ /// List of current discrete timesteps in the diffusion chain
7920 timesteps : Vec < usize > ,
21+ /// Current instance of sample being created by diffusion process
8022 sample : Option < Tensor > ,
81- pub config : DPMSolverSinglestepSchedulerConfig ,
23+ pub config : DPMSolverSchedulerConfig ,
8224}
8325
8426impl DPMSolverSinglestepScheduler {
85- pub fn new ( inference_steps : usize , config : DPMSolverSinglestepSchedulerConfig ) -> Self {
27+ pub fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
8628 let betas = match config. beta_schedule {
8729 BetaSchedule :: ScaledLinear => Tensor :: linspace (
8830 config. beta_start . sqrt ( ) ,
@@ -462,8 +404,6 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) ->
462404 . chain ( [ & [ 1 ] [ ..] ] )
463405 . flatten ( )
464406 . map ( |v| * v)
465- . collect ( )
466- }
467407 } else if solver_order == 1 {
468408 repeat ( & [ 1 ] [ ..] ) . take ( steps) . flatten ( ) . map ( |v| * v) . collect ( )
469409 } else {
@@ -473,11 +413,9 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) ->
473413 if solver_order == 3 {
474414 repeat ( & [ 1 , 2 , 3 ] [ ..] ) . take ( steps / 3 ) . flatten ( ) . map ( |v| * v) . collect ( )
475415 } else if solver_order == 2 {
476- repeat ( & [ 1 , 2 ] [ ..] ) . take ( steps / 2 ) . flatten ( ) . map ( |v| * v ) . collect ( )
416+ repeat ( dbg ! ( & [ 1 , 2 ] [ ..] ) ) . take ( dbg ! ( steps / 2 ) ) . flatten ( ) . map ( |v| dbg ! ( * v ) ) . collect ( )
477417 } else if solver_order == 1 {
478418 repeat ( & [ 1 ] [ ..] ) . take ( steps) . flatten ( ) . map ( |v| * v) . collect ( )
479- } else {
480- panic ! ( "invalid solver_order" ) ;
481419 }
482420 }
483421}
0 commit comments