import torch
import os
import PIL

from typing import List, Optional, Union
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from PIL import Image
from diffusers.utils import logging

VECTOR_DATA_FOLDER = "vector_data"
VECTOR_DATA_DICT = "vector_data"

logger = logging.get_logger(__name__)

def get_ddpm_inversion_scheduler(
    scheduler,
    step_function,
    config,
    timesteps,
    save_timesteps,
    latents,
    x_ts,
    x_ts_c_hat,
    save_intermediate_results,
    pipe,
    x_0,
    v1s_images,
    v2s_images,
    deltas_images,
    v1_x0s,
    v2_x0s,
    deltas_x0s,
    folder_name,
    image_name,
    time_measure_n,
):
    """Attach inversion/editing state to ``scheduler`` and override its ``step``.

    The scheduler object is mutated in place and returned. Every argument
    except ``x_0`` is stored on the scheduler as an attribute (``config`` as
    ``_config``, ``timesteps`` as ``_timesteps``); ``x_0`` is only used to
    build ``scheduler.x_0s`` through ``create_xts`` with ``no_add_noise=True``.
    ``_save_timesteps`` falls back to ``timesteps`` when ``save_timesteps`` is
    falsy.

    The installed ``step`` splits the batch along dim 0: element 0 is handled
    by ``step_save_latents`` and the remaining elements by
    ``step_use_latents``.

    Returns:
        The same ``scheduler`` instance that was passed in.
    """

    def step(
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ):
        """Run the save-latents and use-latents steps and join their outputs.

        Always returns a 1-tuple holding the concatenated tensor, whatever the
        value of ``return_dict`` (which is only forwarded to the two helpers).
        """
        forwarded = (
            eta,
            use_clipped_model_output,
            generator,
            variance_noise,
            return_dict,
        )
        # First batch element: the image being inverted.
        inversion_out = step_save_latents(
            scheduler,
            model_output[:1, :, :, :],
            timestep,
            sample[:1, :, :, :],
            *forwarded,
        )
        # Remaining batch elements: samples driven by the stored latents.
        editing_out = step_use_latents(
            scheduler,
            model_output[1:, :, :, :],
            timestep,
            sample[1:, :, :, :],
            *forwarded,
        )
        return (torch.cat((inversion_out[0], editing_out[0]), dim=0),)

    leading_state = {
        "step_function": step_function,
        "is_save": True,
        "_timesteps": timesteps,
        "_save_timesteps": save_timesteps or timesteps,
        "_config": config,
        "latents": latents,
        "x_ts": x_ts,
        "x_ts_c_hat": x_ts_c_hat,
        "step": step,
        "save_intermediate_results": save_intermediate_results,
        "pipe": pipe,
        "v1s_images": v1s_images,
        "v2s_images": v2s_images,
        "deltas_images": deltas_images,
        "v1_x0s": v1_x0s,
        "v2_x0s": v2_x0s,
        "deltas_x0s": deltas_x0s,
        "clean_step_run": False,
    }
    for attr_name, attr_value in leading_state.items():
        setattr(scheduler, attr_name, attr_value)

    # Noise-free trajectory: with zero noise the values come straight from
    # pipe.scheduler.add_noise applied to x_0.
    scheduler.x_0s = create_xts(
        config.noise_shift_delta,
        config.noise_timesteps,
        config.clean_step_timestep,
        None,
        pipe.scheduler,
        timesteps,
        x_0,
        no_add_noise=True,
    )

    trailing_state = {
        "folder_name": folder_name,
        "image_name": image_name,
        "p_to_p": False,
        "p_to_p_replace": False,
        "time_measure_n": time_measure_n,
    }
    for attr_name, attr_value in trailing_state.items():
        setattr(scheduler, attr_name, attr_value)

    return scheduler

def step_save_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Record the residual between the stored next state and the model step.

    Calls ``self.step_function`` to get the step prediction, appends that
    prediction to ``self.x_ts_c_hat`` and appends the (un-normalized)
    difference ``self.x_ts[next] - prediction`` to ``self.latents``.

    Args:
        self: Scheduler carrying ``_save_timesteps``, ``clean_step_run``,
            ``step_function``, ``x_ts``, ``x_ts_c_hat``, ``latents`` and
            ``_config.max_norm_zs``.
        model_output, timestep, sample, eta, use_clipped_model_output,
        generator, variance_noise: Forwarded to ``self.step_function``.
        return_dict: Selects the return form (see below); the step function
            itself is always called with ``return_dict=False``.

    Returns:
        When ``return_dict`` is falsy, a 1-tuple with the prediction plus the
        norm-limited residual. Otherwise a ``DDIMSchedulerOutput`` whose
        ``prev_sample`` is the stored ``self.x_ts[next]`` entry.
    """
    if self.clean_step_run:
        # During the clean step both indices address the last stored entry.
        current_idx = -1
        following_idx = -1
    else:
        current_idx = self._save_timesteps.index(timestep)
        following_idx = current_idx + 1

    predicted_mean = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    stored_next = self.x_ts[following_idx]
    self.x_ts_c_hat.append(predicted_mean)

    residual = stored_next - predicted_mean
    # The raw residual is what gets stored; the norm limit is applied after.
    self.latents.append(residual)
    limited_residual, _ = normalize(residual, current_idx, self._config.max_norm_zs)

    reconstructed_next = predicted_mean + limited_residual

    if return_dict:
        return DDIMSchedulerOutput(prev_sample=stored_next, pred_original_sample=None)

    return (reconstructed_next,)

def step_use_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Compute the next sample from the stored trajectory and weighted offsets.

    The result is
    ``normalize_coefficient * self.x_ts[next] + w1 * v1 + w2 * v2`` where
    ``w1``/``w2`` come from ``self._config.ws1``/``ws2`` at the current
    timestep index and ``v1``/``v2`` depend on ``self._config.breakdown``:

    * ``"x_t_hat_c"`` / ``"x_t_hat_c_with_zeros"``: the batch is treated as
      consecutive groups along dim 0 and ``v1``/``v2`` are built from a copy
      of one group placed at the position of another (see inline comments).
    * ``"x_t_c_hat"``: raises ``NotImplementedError``.
    * anything else: ``v1`` is the step prediction minus the scaled stored
      prediction and ``v2`` is ``0``.

    Side effects:
        * With ``self.save_intermediate_results`` and not ``self.p_to_p``,
          decoded images are appended to ``v1s_images``, ``v2s_images``,
          ``deltas_images``, ``v1_x0s``, ``v2_x0s`` and ``deltas_x0s``.
        * With breakdown ``"x_t_hat_c_with_zeros"`` and not ``self.p_to_p``,
          per-timestep tensors are collected in ``self.vector_data`` and the
          dict is written with ``torch.save`` when ``timestep`` equals
          ``self._timesteps[-1]``.

    Args:
        self: Scheduler carrying the state set up by
            ``get_ddpm_inversion_scheduler``.
        model_output, timestep, sample, use_clipped_model_output, generator,
        variance_noise: Forwarded to ``self.step_function``.
        eta: Forwarded to ``self.step_function``; replaced by ``0`` when the
            normalization coefficient equals 0.
        return_dict: When falsy a 1-tuple is returned.

    Returns:
        ``(x_t_minus_1,)`` or ``DDIMSchedulerOutput(prev_sample=x_t_minus_1,
        pred_original_sample=None)``.

    Raises:
        NotImplementedError: If ``self._config.breakdown == "x_t_c_hat"``.
    """
    # During the clean step both indices are -1, i.e. the last stored entry.
    timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
    next_timestep_index = (
        timestep_index + 1 if not self.clean_step_run else -1
    )
    z_t = self.latents[next_timestep_index]  # + 1 because latents[0] is X_T

    # Only the scaling coefficient is needed here; the scaled latent itself
    # is discarded.
    _, normalize_coefficient = normalize(
        z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
        timestep_index,
        self._config.max_norm_zs,
    )

    if normalize_coefficient == 0:
        eta = 0

    x_t_hat_c_hat = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    w1 = self._config.ws1[timestep_index]
    w2 = self._config.ws2[timestep_index]

    # Stored trajectory entry, broadcast to the batch of this step.
    x_t_minus_1_exact = self.x_ts[next_timestep_index]
    x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)

    x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
    if self._config.breakdown == "x_t_c_hat":
        raise NotImplementedError("breakdown x_t_c_hat not implemented yet")

    # First entry of the stored prediction, broadcast to the batch.
    x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)

    if (
        self._config.breakdown == "x_t_hat_c"
        or self._config.breakdown == "x_t_hat_c_with_zeros"
    ):
        # One leading batch slot is skipped unless time_measure_n is set.
        zero_index_reconstruction = 1 if not self.time_measure_n else 0
        # The remaining batch is divided into 3 equal groups for
        # "x_t_hat_c_with_zeros" (without p_to_p) and into 2 otherwise.
        edit_prompts_num = (
            (model_output.size(0) - zero_index_reconstruction) // 3
            if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
            else (model_output.size(0) - zero_index_reconstruction) // 2
        )
        # (start, stop) of the first group and of the second group.
        x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
        edit_images_indices = (
            edit_prompts_num + zero_index_reconstruction,
            (
                model_output.size(0)
                if self._config.breakdown == "x_t_hat_c"
                else zero_index_reconstruction + 2 * edit_prompts_num
            ),
        )
        # Zero everywhere except the second group's rows, which receive the
        # first group's predictions.
        x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
        x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
            x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
        ]
        v1 = x_t_hat_c_hat - x_t_hat_c
        v2 = x_t_hat_c - normalize_coefficient * x_t_c
        if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
            path = os.path.join(
                self.folder_name,
                VECTOR_DATA_FOLDER,
                self.image_name,
            )
            # Create the output folder and the dict on first use only.
            if not hasattr(self, VECTOR_DATA_DICT):
                os.makedirs(path, exist_ok=True)
                self.vector_data = dict()

            # NOTE(review): indexing [1] needs the stored prediction to hold
            # at least two entries along dim 0 — confirm against the caller.
            x_t_0 = x_t_c_hat[1]
            # NOTE(review): the offset is a literal 1 rather than
            # zero_index_reconstruction; the two differ when time_measure_n
            # is truthy — confirm this is intended.
            empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
            x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]

            # timestep must expose .item() on this path (e.g. a tensor).
            self.vector_data[timestep.item()] = dict()
            self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
            self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
            self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
            self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
                0
            ].expand_as(x_t_hat_0)
            self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
                next_timestep_index
            ].expand_as(x_t_hat_0)

    else:  # no breakdown
        v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
        v2 = 0

    if self.save_intermediate_results and not self.p_to_p:
        # Decode each offset alone and added to the noise-free trajectory
        # entry, and keep the resulting images for later inspection.
        delta = v1 + v2
        v1_plus_x0 = self.x_0s[next_timestep_index] + v1
        v2_plus_x0 = self.x_0s[next_timestep_index] + v2
        delta_plus_x0 = self.x_0s[next_timestep_index] + delta

        v1_images = decode_latents(v1, self.pipe)
        self.v1s_images.append(v1_images)
        # For "no_breakdown" a 1x1 placeholder image is stored instead of
        # decoding v2.
        v2_images = (
            decode_latents(v2, self.pipe)
            if self._config.breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.v2s_images.append(v2_images)
        delta_images = decode_latents(delta, self.pipe)
        self.deltas_images.append(delta_images)
        v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
        self.v1_x0s.append(v1_plus_x0_images)
        v2_plus_x0_images = (
            decode_latents(v2_plus_x0, self.pipe)
            if self._config.breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.v2_x0s.append(v2_plus_x0_images)
        delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
        self.deltas_x0s.append(delta_plus_x0_images)

    # Scaled stored state plus the two weighted offsets.
    x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2

    if (
        self._config.breakdown == "x_t_hat_c"
        or self._config.breakdown == "x_t_hat_c_with_zeros"
    ):
        x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
            edit_images_indices[0] : edit_images_indices[1]
        ] # update x_t_hat_c to be x_t_hat_c_hat
        if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
            # The third group also receives the second group's result.
            x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
                x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
            )
            self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            # Write the collected tensors once the last timestep is reached.
            if timestep == self._timesteps[-1]:
                torch.save(
                    self.vector_data,
                    os.path.join(
                        path,
                        f"{VECTOR_DATA_DICT}.pt",
                    ),
                )
    # p_to_p_force_perfect_reconstruction: unless time_measure_n is set, the
    # first row is overwritten with the stored trajectory entry.
    if not self.time_measure_n:
        x_t_minus_1[0] = x_t_minus_1_exact[0]

    if not return_dict:
        return (x_t_minus_1,)

    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1,
        pred_original_sample=None,
    )

def create_xts(
    noise_shift_delta,
    noise_timesteps,
    clean_step_timestep,
    generator,
    scheduler,
    timesteps,
    x_0,
    no_add_noise=False,
):
    """Build the list of noised versions of ``x_0`` used as the trajectory.

    Args:
        noise_shift_delta: Fraction of the gap ``timesteps[0] - timesteps[1]``
            subtracted from every timestep when ``noise_timesteps`` is None.
        noise_timesteps: Explicit noising timesteps, or None to derive them
            from ``timesteps``.
        clean_step_timestep: When ``> 0`` one additional ``x_0`` entry is
            appended at the end.
        generator: Passed to ``torch.randn`` when noise is drawn.
        scheduler: Object providing ``add_noise(samples, noise, timesteps)``.
        timesteps: Sampling timesteps.
        x_0: Clean sample; expanded along dim 0 to one copy per noising
            timestep.
        no_add_noise: Use all-zero noise instead of sampling.

    Returns:
        A list: one ``add_noise`` result (with a leading dim of 1) per noising
        timestep that precedes the first non-positive one, followed by ``x_0``
        repeated ``len(timesteps) - first_non_positive_index`` times, one more
        ``x_0``, and a final ``x_0`` if ``clean_step_timestep > 0``.
    """
    if noise_timesteps is None:
        shift = int(noise_shift_delta * (timesteps[0] - timesteps[1]))
        noise_timesteps = [step - shift for step in timesteps]

    # Index of the first noising timestep that is <= 0 (or the full length).
    first_x_0_idx = next(
        (pos for pos, step in enumerate(noise_timesteps) if step <= 0),
        len(noise_timesteps),
    )
    noise_timesteps = noise_timesteps[:first_x_0_idx]

    replicated_x_0 = x_0.expand(len(noise_timesteps), -1, -1, -1)
    if no_add_noise:
        noise = torch.zeros_like(replicated_x_0)
    else:
        # Sampled on CPU, then moved to the device of x_0.
        cpu_noise = torch.randn(
            replicated_x_0.size(), generator=generator, device="cpu"
        )
        noise = cpu_noise.to(x_0.device)

    noised_batch = scheduler.add_noise(
        replicated_x_0,
        noise,
        torch.IntTensor(noise_timesteps),
    )

    trajectory = [entry.unsqueeze(dim=0) for entry in noised_batch]
    trajectory.extend(x_0 for _ in range(len(timesteps) - first_x_0_idx))
    trajectory.append(x_0)
    if clean_step_timestep > 0:
        trajectory.append(x_0)
    return trajectory

def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Scale ``z_t`` down so its norm does not exceed ``max_norm_zs[i]``.

    Args:
        z_t: Tensor to limit.
        i: Index into ``max_norm_zs``.
        max_norm_zs: Sequence of per-step norm limits; a negative entry
            disables limiting for that step.

    Returns:
        ``(z_t, 1)`` unchanged when the limit is negative or the norm of
        ``z_t`` is strictly below the limit; otherwise
        ``(z_t * coeff, coeff)`` with ``coeff = limit / norm``.
    """
    limit = max_norm_zs[i]
    if not limit < 0:
        magnitude = torch.norm(z_t)
        if not magnitude < limit:
            scale = limit / magnitude
            return z_t * scale, scale
    return z_t, 1

def decode_latents(latent, pipe):
    """Decode a latent with ``pipe.vae`` and post-process it to PIL output.

    The latent is divided by ``pipe.vae.config.scaling_factor`` before
    decoding; the first element of the decoder's tuple output is handed to
    ``pipe.image_processor.postprocess`` with ``output_type="pil"``.
    """
    vae = pipe.vae
    decoded = vae.decode(latent / vae.config.scaling_factor, return_dict=False)
    image_tensor = decoded[0]
    return pipe.image_processor.postprocess(image_tensor, output_type="pil")

def deterministic_ddim_step(
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """DDIM update without the random-noise term.

    Computes the predicted original sample and the "direction pointing to
    x_t" of formula (12) in https://arxiv.org/pdf/2010.02502.pdf and returns
    their combination. ``eta`` still shrinks the direction term through
    ``std_dev_t``, but no noise is added.

    Args:
        model_output: Output of the diffusion model.
        timestep: Current timestep; indexes ``scheduler.alphas_cumprod``.
        sample: Current sample.
        eta: Weight of the variance used in the direction term.
        use_clipped_model_output: Re-derive the predicted noise from the
            (possibly clipped/thresholded) predicted original sample.
        generator, variance_noise, return_dict: Accepted for signature
            compatibility; not used.
        scheduler: Scheduler supplying ``config``, ``num_inference_steps``,
            ``alphas_cumprod``, ``final_alpha_cumprod``, ``_get_variance`` and
            ``_threshold_sample``.

    Returns:
        The previous-sample tensor (a bare tensor, not a tuple or output
        object).

    Raises:
        ValueError: If ``scheduler.num_inference_steps`` is None or the
            configured ``prediction_type`` is unknown.
    """
    inference_steps = scheduler.num_inference_steps
    if inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    cfg = scheduler.config
    prev_timestep = timestep - cfg.num_train_timesteps // inference_steps

    # Cumulative alphas at the current and previous timestep.
    alpha_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        alpha_prev = scheduler.final_alpha_cumprod
    beta_t = 1 - alpha_t

    sqrt_alpha_t = alpha_t ** (0.5)
    sqrt_beta_t = beta_t ** (0.5)

    # Predicted x_0 and predicted noise for the configured parameterization.
    prediction_type = cfg.prediction_type
    if prediction_type == "epsilon":
        x0_estimate = (sample - sqrt_beta_t * model_output) / sqrt_alpha_t
        noise_estimate = model_output
    elif prediction_type == "sample":
        x0_estimate = model_output
        noise_estimate = (sample - sqrt_alpha_t * x0_estimate) / sqrt_beta_t
    elif prediction_type == "v_prediction":
        x0_estimate = sqrt_alpha_t * sample - sqrt_beta_t * model_output
        noise_estimate = sqrt_alpha_t * model_output + sqrt_beta_t * sample
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # Threshold or clip the predicted x_0.
    if cfg.thresholding:
        x0_estimate = scheduler._threshold_sample(x0_estimate)
    elif cfg.clip_sample:
        clip_bound = cfg.clip_sample_range
        x0_estimate = x0_estimate.clamp(-clip_bound, clip_bound)

    # sigma_t(eta), formula (16).
    step_variance = scheduler._get_variance(timestep, prev_timestep)
    sigma_t = eta * step_variance ** (0.5)

    if use_clipped_model_output:
        # Keep the noise estimate consistent with the clipped x_0.
        noise_estimate = (sample - sqrt_alpha_t * x0_estimate) / sqrt_beta_t

    direction = (1 - alpha_prev - sigma_t**2) ** (0.5) * noise_estimate
    return alpha_prev ** (0.5) * x0_estimate + direction


def deterministic_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Euler-ancestral style update with the noise injection left out.

    The sample is moved from the current sigma to ``sigma_down`` (derived
    from the current and next entries of ``scheduler.sigmas``); the
    ``sigma_up`` noise term is computed only to obtain ``sigma_down`` and no
    noise is added. ``scheduler._step_index`` is incremented by one.

    Args:
        model_output (`torch.FloatTensor`):
            Output of the diffusion model.
        timestep (`float`):
            Current timestep; must not be an ``int`` or an integer tensor. It
            is only used to initialise the step index when that is None.
        sample (`torch.FloatTensor`):
            Current sample; upcast to float32 for the computation.
        eta, use_clipped_model_output, generator, variance_noise, return_dict:
            Accepted for signature compatibility; not used.
        scheduler:
            Scheduler supplying ``sigmas``, ``step_index``, ``_step_index``,
            ``_init_step_index`` and ``config.prediction_type``.

    Returns:
        The previous-sample tensor, cast to ``model_output.dtype`` (a bare
        tensor, not a tuple or output object).

    Raises:
        ValueError: For integer timesteps or an unknown ``prediction_type``.
        NotImplementedError: For ``prediction_type == "sample"``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    sigma = scheduler.sigmas[scheduler.step_index]

    # Work in float32 to avoid precision issues.
    sample = sample.to(torch.float32)

    # Predicted x_0 from the sigma-scaled model output.
    prediction_type = scheduler.config.prediction_type
    if prediction_type == "epsilon":
        denoised = sample - sigma * model_output
    elif prediction_type == "v_prediction":
        # c_out * model_output + c_skip * sample
        denoised = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    elif prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_next = scheduler.sigmas[scheduler.step_index + 1]
    noise_level = (sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2) ** 0.5
    target_sigma = (sigma_next**2 - noise_level**2) ** 0.5

    # Euler step along the ODE derivative.
    slope = (sample - denoised) / sigma
    step_size = target_sigma - sigma
    result = (sample + slope * step_size).to(model_output.dtype)

    scheduler._step_index += 1

    return result


def deterministic_non_ancestral_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Euler update from the current sigma to the next, without added noise.

    ``sigma_hat = sigma * (gamma + 1)`` is computed from ``s_churn``,
    ``s_tmin`` and ``s_tmax``, but no noise is injected into the sample.
    ``scheduler._step_index`` is incremented by one.

    Args:
        model_output (`torch.FloatTensor`):
            Output of the diffusion model.
        timestep (`float`):
            Current timestep; must not be an ``int`` or an integer tensor. It
            is only used to initialise the step index when that is None.
        sample (`torch.FloatTensor`):
            Current sample; upcast to float32 for the computation.
        s_churn (`float`):
            Numerator of gamma, divided by ``len(scheduler.sigmas) - 1`` and
            capped at ``2**0.5 - 1``.
        s_tmin (`float`), s_tmax (`float`):
            Gamma is non-zero only when ``s_tmin <= sigma <= s_tmax``.
        eta, use_clipped_model_output, s_noise, generator, variance_noise,
        return_dict:
            Accepted for signature compatibility; not used.
        scheduler:
            Scheduler supplying ``sigmas``, ``step_index``, ``_step_index``,
            ``_init_step_index``, ``is_scale_input_called`` and
            ``config.prediction_type``.

    Returns:
        The previous-sample tensor, cast to ``model_output.dtype`` (a bare
        tensor, not a tuple or output object).

    Raises:
        ValueError: For integer timesteps or an unknown ``prediction_type``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if not scheduler.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    # Work in float32 to avoid precision issues.
    sample = sample.to(torch.float32)

    sigma = scheduler.sigmas[scheduler.step_index]

    if s_tmin <= sigma <= s_tmax:
        gamma = min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
    else:
        gamma = 0.0
    sigma_hat = sigma * (gamma + 1)

    # Predicted x_0. "original_sample" is kept as an alias of "sample" for
    # backwards compatibility.
    prediction_type = scheduler.config.prediction_type
    if prediction_type in ("original_sample", "sample"):
        denoised = model_output
    elif prediction_type == "epsilon":
        denoised = sample - sigma_hat * model_output
    elif prediction_type == "v_prediction":
        # c_out * model_output + c_skip * sample
        denoised = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    # Euler step along the ODE derivative.
    slope = (sample - denoised) / sigma_hat
    step_size = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
    result = (sample + slope * step_size).to(model_output.dtype)

    scheduler._step_index += 1

    return result


def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Return the DDPM posterior mean for the previous timestep.

    Implements formula (7) of https://arxiv.org/pdf/2006.11239.pdf using the
    predicted original sample; no variance noise is added.

    Args:
        model_output (`torch.FloatTensor`):
            Output of the diffusion model. When it has twice the channels of
            ``sample`` and ``scheduler.variance_type`` is ``"learned"`` or
            ``"learned_range"``, only the first half along dim 1 is used.
        timestep (`float`):
            Current timestep; indexes ``scheduler.alphas_cumprod`` and is
            passed to ``scheduler.previous_timestep``.
        sample (`torch.FloatTensor`):
            Current sample.
        eta, use_clipped_model_output, generator, variance_noise, return_dict:
            Accepted for signature compatibility; not used.
        scheduler:
            Scheduler supplying ``previous_timestep``, ``variance_type``,
            ``alphas_cumprod``, ``one``, ``_threshold_sample`` and ``config``.

    Returns:
        The predicted previous-sample mean (a bare tensor, not a tuple or
        output object).

    Raises:
        ValueError: If the configured ``prediction_type`` is unknown.
    """
    prev_t = scheduler.previous_timestep(timestep)

    channels = sample.shape[1]
    if model_output.shape[1] == channels * 2 and scheduler.variance_type in (
        "learned",
        "learned_range",
    ):
        # Drop the predicted-variance half; it is not needed for the mean.
        model_output, _ = torch.split(model_output, channels, dim=1)

    # Cumulative and per-step alphas/betas.
    alpha_bar_t = scheduler.alphas_cumprod[timestep]
    if prev_t >= 0:
        alpha_bar_prev = scheduler.alphas_cumprod[prev_t]
    else:
        alpha_bar_prev = scheduler.one
    beta_bar_t = 1 - alpha_bar_t
    beta_bar_prev = 1 - alpha_bar_prev
    alpha_step = alpha_bar_t / alpha_bar_prev
    beta_step = 1 - alpha_step

    # Predicted x_0, formula (15) of the same paper.
    prediction_type = scheduler.config.prediction_type
    if prediction_type == "epsilon":
        x0_estimate = (
            sample - beta_bar_t ** (0.5) * model_output
        ) / alpha_bar_t ** (0.5)
    elif prediction_type == "sample":
        x0_estimate = model_output
    elif prediction_type == "v_prediction":
        x0_estimate = (alpha_bar_t**0.5) * sample - (beta_bar_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction`  for the DDPMScheduler."
        )

    # Threshold or clip the predicted x_0.
    if scheduler.config.thresholding:
        x0_estimate = scheduler._threshold_sample(x0_estimate)
    elif scheduler.config.clip_sample:
        clip_bound = scheduler.config.clip_sample_range
        x0_estimate = x0_estimate.clamp(-clip_bound, clip_bound)

    # Weights of x_0 and x_t in the posterior mean.
    x0_weight = (alpha_bar_prev ** (0.5) * beta_step) / beta_bar_t
    xt_weight = alpha_step ** (0.5) * beta_bar_prev / beta_bar_t

    return x0_weight * x0_estimate + xt_weight * sample