ldm add plms sampler
This commit is contained in:
parent
55197f2209
commit
35b92ba9de
@ -16,6 +16,7 @@ export default async function inpaint(
|
|||||||
fd.append('mask', mask)
|
fd.append('mask', mask)
|
||||||
|
|
||||||
fd.append('ldmSteps', settings.ldmSteps.toString())
|
fd.append('ldmSteps', settings.ldmSteps.toString())
|
||||||
|
fd.append('ldmSampler', settings.ldmSampler.toString())
|
||||||
fd.append('hdStrategy', settings.hdStrategy)
|
fd.append('hdStrategy', settings.hdStrategy)
|
||||||
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
|
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
|
||||||
fd.append(
|
fd.append(
|
||||||
|
@ -11,6 +11,11 @@ export enum HDStrategy {
|
|||||||
CROP = 'Crop',
|
CROP = 'Crop',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export enum LDMSampler {
|
||||||
|
ddim = 'ddim',
|
||||||
|
plms = 'plms',
|
||||||
|
}
|
||||||
|
|
||||||
function HDSettingBlock() {
|
function HDSettingBlock() {
|
||||||
const [setting, setSettingState] = useRecoilState(settingState)
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import React, { ReactNode } from 'react'
|
|||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState } from '../../store/Atoms'
|
||||||
import Selector from '../shared/Selector'
|
import Selector from '../shared/Selector'
|
||||||
|
import { LDMSampler } from './HDSettingBlock'
|
||||||
import NumberInputSetting from './NumberInputSetting'
|
import NumberInputSetting from './NumberInputSetting'
|
||||||
import SettingBlock from './SettingBlock'
|
import SettingBlock from './SettingBlock'
|
||||||
|
|
||||||
@ -19,6 +20,12 @@ function ModelSettingBlock() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onLDMSamplerChange = (value: LDMSampler) => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, ldmSampler: value }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const renderModelDesc = (
|
const renderModelDesc = (
|
||||||
name: string,
|
name: string,
|
||||||
paperUrl: string,
|
paperUrl: string,
|
||||||
@ -65,6 +72,19 @@ function ModelSettingBlock() {
|
|||||||
})
|
})
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<SettingBlock
|
||||||
|
className="sub-setting-block"
|
||||||
|
title="Sampler"
|
||||||
|
input={
|
||||||
|
<Selector
|
||||||
|
width={80}
|
||||||
|
value={setting.ldmSampler as string}
|
||||||
|
options={Object.values(LDMSampler)}
|
||||||
|
onChange={val => onLDMSamplerChange(val as LDMSampler)}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ export default function ShortcutsModal() {
|
|||||||
/>
|
/>
|
||||||
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
|
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
|
||||||
<ShortCut content="Pan" keys={['Space & Drag']} />
|
<ShortCut content="Pan" keys={['Space & Drag']} />
|
||||||
<ShortCut content="View Original Image" keys={['Hold Tag']} />
|
<ShortCut content="View Original Image" keys={['Hold Tab']} />
|
||||||
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
|
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
|
||||||
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
|
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
|
||||||
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
|
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
|
||||||
|
@ -2,8 +2,9 @@
|
|||||||
all: unset;
|
all: unset;
|
||||||
flex: 1 0 auto;
|
flex: 1 0 auto;
|
||||||
border-radius: 0.5rem;
|
border-radius: 0.5rem;
|
||||||
padding: 0.4rem 0.8rem;
|
padding: 0 0.8rem;
|
||||||
outline: 1px solid var(--border-color);
|
outline: 1px solid var(--border-color);
|
||||||
|
height: 36px;
|
||||||
|
|
||||||
&:focus-visible {
|
&:focus-visible {
|
||||||
outline: 1px solid var(--yellow-accent);
|
outline: 1px solid var(--yellow-accent);
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { atom } from 'recoil'
|
import { atom } from 'recoil'
|
||||||
import { HDStrategy } from '../components/Settings/HDSettingBlock'
|
import { HDStrategy, LDMSampler } from '../components/Settings/HDSettingBlock'
|
||||||
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
||||||
import { ToastState } from '../components/shared/Toast'
|
import { ToastState } from '../components/shared/Toast'
|
||||||
|
|
||||||
@ -43,6 +43,7 @@ export interface Settings {
|
|||||||
|
|
||||||
// For LDM
|
// For LDM
|
||||||
ldmSteps: number
|
ldmSteps: number
|
||||||
|
ldmSampler: LDMSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
export const settingStateDefault = {
|
export const settingStateDefault = {
|
||||||
@ -50,6 +51,7 @@ export const settingStateDefault = {
|
|||||||
runInpaintingManually: false,
|
runInpaintingManually: false,
|
||||||
model: AIModel.LAMA,
|
model: AIModel.LAMA,
|
||||||
ldmSteps: 50,
|
ldmSteps: 50,
|
||||||
|
ldmSampler: LDMSampler.plms,
|
||||||
hdStrategy: HDStrategy.RESIZE,
|
hdStrategy: HDStrategy.RESIZE,
|
||||||
hdStrategyResizeLimit: 2048,
|
hdStrategyResizeLimit: 2048,
|
||||||
hdStrategyCropTrigerSize: 2048,
|
hdStrategyCropTrigerSize: 2048,
|
||||||
|
193
lama_cleaner/model/ddim_sampler.py
Normal file
193
lama_cleaner/model/ddim_sampler.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear"):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(
|
||||||
|
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||||
|
):
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
|
ddim_discr_method=ddim_discretize,
|
||||||
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
|
# array([1])
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
||||||
|
assert (
|
||||||
|
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||||
|
), "alphas have to be defined for each timestep"
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer("betas", to_torch(self.model.betas))
|
||||||
|
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer(
|
||||||
|
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_one_minus_alphas_cumprod",
|
||||||
|
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recipm1_alphas_cumprod",
|
||||||
|
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||||
|
alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||||
|
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||||
|
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||||
|
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev)
|
||||||
|
/ (1 - self.alphas_cumprod)
|
||||||
|
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self, steps, conditioning, batch_size, shape):
|
||||||
|
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
|
||||||
|
# samples: 1,3,128,128
|
||||||
|
return self.ddim_sampling(
|
||||||
|
conditioning,
|
||||||
|
size,
|
||||||
|
quantize_denoised=False,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=0,
|
||||||
|
temperature=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_sampling(
|
||||||
|
self,
|
||||||
|
cond,
|
||||||
|
shape,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
||||||
|
timesteps = (
|
||||||
|
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
)
|
||||||
|
|
||||||
|
time_range = (
|
||||||
|
reversed(range(0, timesteps))
|
||||||
|
if ddim_use_original_steps
|
||||||
|
else np.flip(timesteps)
|
||||||
|
)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
outs = self.p_sample_ddim(
|
||||||
|
img,
|
||||||
|
cond,
|
||||||
|
ts,
|
||||||
|
index=index,
|
||||||
|
use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised,
|
||||||
|
temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
)
|
||||||
|
img, _ = outs
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_ddim(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
c,
|
||||||
|
t,
|
||||||
|
index,
|
||||||
|
repeat_noise=False,
|
||||||
|
use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = (
|
||||||
|
self.model.alphas_cumprod_prev
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_alphas_prev
|
||||||
|
)
|
||||||
|
sqrt_one_minus_alphas = (
|
||||||
|
self.model.sqrt_one_minus_alphas_cumprod
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sqrt_one_minus_alphas
|
||||||
|
)
|
||||||
|
sigmas = (
|
||||||
|
self.model.ddim_sigmas_for_original_num_steps
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sigmas
|
||||||
|
)
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full(
|
||||||
|
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised: # 没用
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.0: # 没用
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
@ -5,17 +5,15 @@ import torch
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
||||||
|
from lama_cleaner.model.plms_sampler import PLMSSampler
|
||||||
|
from lama_cleaner.schema import Config, LDMSampler
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tqdm import tqdm
|
|
||||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||||
from lama_cleaner.model.utils import (
|
from lama_cleaner.model.utils import (
|
||||||
make_beta_schedule,
|
make_beta_schedule,
|
||||||
make_ddim_timesteps,
|
|
||||||
make_ddim_sampling_parameters,
|
|
||||||
noise_like,
|
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,7 +92,7 @@ class DDPM(nn.Module):
|
|||||||
self.linear_start = linear_start
|
self.linear_start = linear_start
|
||||||
self.linear_end = linear_end
|
self.linear_end = linear_end
|
||||||
assert (
|
assert (
|
||||||
alphas_cumprod.shape[0] == self.num_timesteps
|
alphas_cumprod.shape[0] == self.num_timesteps
|
||||||
), "alphas have to be defined for each timestep"
|
), "alphas have to be defined for each timestep"
|
||||||
|
|
||||||
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||||
@ -120,7 +118,7 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||||
posterior_variance = (1 - self.v_posterior) * betas * (
|
posterior_variance = (1 - self.v_posterior) * betas * (
|
||||||
1.0 - alphas_cumprod_prev
|
1.0 - alphas_cumprod_prev
|
||||||
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
||||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||||
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
||||||
@ -142,16 +140,16 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
if self.parameterization == "eps":
|
if self.parameterization == "eps":
|
||||||
lvlb_weights = self.betas ** 2 / (
|
lvlb_weights = self.betas ** 2 / (
|
||||||
2
|
2
|
||||||
* self.posterior_variance
|
* self.posterior_variance
|
||||||
* to_torch(alphas)
|
* to_torch(alphas)
|
||||||
* (1 - self.alphas_cumprod)
|
* (1 - self.alphas_cumprod)
|
||||||
)
|
)
|
||||||
elif self.parameterization == "x0":
|
elif self.parameterization == "x0":
|
||||||
lvlb_weights = (
|
lvlb_weights = (
|
||||||
0.5
|
0.5
|
||||||
* np.sqrt(torch.Tensor(alphas_cumprod))
|
* np.sqrt(torch.Tensor(alphas_cumprod))
|
||||||
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("mu not supported")
|
raise NotImplementedError("mu not supported")
|
||||||
@ -221,192 +219,6 @@ class LatentDiffusion(DDPM):
|
|||||||
return x_recon
|
return x_recon
|
||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
|
||||||
def __init__(self, model, schedule="linear"):
|
|
||||||
super().__init__()
|
|
||||||
self.model = model
|
|
||||||
self.ddpm_num_timesteps = model.num_timesteps
|
|
||||||
self.schedule = schedule
|
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
|
||||||
setattr(self, name, attr)
|
|
||||||
|
|
||||||
def make_schedule(
|
|
||||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
|
||||||
):
|
|
||||||
self.ddim_timesteps = make_ddim_timesteps(
|
|
||||||
ddim_discr_method=ddim_discretize,
|
|
||||||
num_ddim_timesteps=ddim_num_steps,
|
|
||||||
# array([1])
|
|
||||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
|
||||||
assert (
|
|
||||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
|
||||||
), "alphas have to be defined for each timestep"
|
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
|
||||||
|
|
||||||
self.register_buffer("betas", to_torch(self.model.betas))
|
|
||||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
|
||||||
self.register_buffer(
|
|
||||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_one_minus_alphas_cumprod",
|
|
||||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_recipm1_alphas_cumprod",
|
|
||||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ddim sampling parameters
|
|
||||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
|
||||||
alphacums=alphas_cumprod.cpu(),
|
|
||||||
ddim_timesteps=self.ddim_timesteps,
|
|
||||||
eta=ddim_eta,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
|
||||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
|
||||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
|
||||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
|
||||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
|
||||||
(1 - self.alphas_cumprod_prev)
|
|
||||||
/ (1 - self.alphas_cumprod)
|
|
||||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample(self, steps, conditioning, batch_size, shape):
|
|
||||||
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
|
||||||
# sampling
|
|
||||||
C, H, W = shape
|
|
||||||
size = (batch_size, C, H, W)
|
|
||||||
|
|
||||||
# samples: 1,3,128,128
|
|
||||||
return self.ddim_sampling(
|
|
||||||
conditioning,
|
|
||||||
size,
|
|
||||||
quantize_denoised=False,
|
|
||||||
ddim_use_original_steps=False,
|
|
||||||
noise_dropout=0,
|
|
||||||
temperature=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def ddim_sampling(
|
|
||||||
self,
|
|
||||||
cond,
|
|
||||||
shape,
|
|
||||||
ddim_use_original_steps=False,
|
|
||||||
quantize_denoised=False,
|
|
||||||
temperature=1.0,
|
|
||||||
noise_dropout=0.0,
|
|
||||||
):
|
|
||||||
device = self.model.betas.device
|
|
||||||
b = shape[0]
|
|
||||||
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
|
||||||
timesteps = (
|
|
||||||
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
|
||||||
)
|
|
||||||
|
|
||||||
time_range = (
|
|
||||||
reversed(range(0, timesteps))
|
|
||||||
if ddim_use_original_steps
|
|
||||||
else np.flip(timesteps)
|
|
||||||
)
|
|
||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
|
||||||
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
|
||||||
|
|
||||||
for i, step in enumerate(iterator):
|
|
||||||
index = total_steps - i - 1
|
|
||||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
|
||||||
|
|
||||||
outs = self.p_sample_ddim(
|
|
||||||
img,
|
|
||||||
cond,
|
|
||||||
ts,
|
|
||||||
index=index,
|
|
||||||
use_original_steps=ddim_use_original_steps,
|
|
||||||
quantize_denoised=quantize_denoised,
|
|
||||||
temperature=temperature,
|
|
||||||
noise_dropout=noise_dropout,
|
|
||||||
)
|
|
||||||
img, _ = outs
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def p_sample_ddim(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
c,
|
|
||||||
t,
|
|
||||||
index,
|
|
||||||
repeat_noise=False,
|
|
||||||
use_original_steps=False,
|
|
||||||
quantize_denoised=False,
|
|
||||||
temperature=1.0,
|
|
||||||
noise_dropout=0.0,
|
|
||||||
):
|
|
||||||
b, *_, device = *x.shape, x.device
|
|
||||||
e_t = self.model.apply_model(x, t, c)
|
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
|
||||||
alphas_prev = (
|
|
||||||
self.model.alphas_cumprod_prev
|
|
||||||
if use_original_steps
|
|
||||||
else self.ddim_alphas_prev
|
|
||||||
)
|
|
||||||
sqrt_one_minus_alphas = (
|
|
||||||
self.model.sqrt_one_minus_alphas_cumprod
|
|
||||||
if use_original_steps
|
|
||||||
else self.ddim_sqrt_one_minus_alphas
|
|
||||||
)
|
|
||||||
sigmas = (
|
|
||||||
self.model.ddim_sigmas_for_original_num_steps
|
|
||||||
if use_original_steps
|
|
||||||
else self.ddim_sigmas
|
|
||||||
)
|
|
||||||
# select parameters corresponding to the currently considered timestep
|
|
||||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
|
||||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
|
||||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
|
||||||
sqrt_one_minus_at = torch.full(
|
|
||||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
# current prediction for x_0
|
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
|
||||||
if quantize_denoised: # 没用
|
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
|
||||||
# direction pointing to x_t
|
|
||||||
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
|
||||||
if noise_dropout > 0.0: # 没用
|
|
||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
|
||||||
return x_prev, pred_x0
|
|
||||||
|
|
||||||
|
|
||||||
def load_jit_model(url, device):
|
def load_jit_model(url, device):
|
||||||
model_path = download_model(url)
|
model_path = download_model(url)
|
||||||
logger.info(f"Load LDM model from: {model_path}")
|
logger.info(f"Load LDM model from: {model_path}")
|
||||||
@ -432,8 +244,7 @@ class LDM(InpaintModel):
|
|||||||
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
||||||
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
||||||
|
|
||||||
model = LatentDiffusion(self.diffusion_model, device)
|
self.model = LatentDiffusion(self.diffusion_model, device)
|
||||||
self.sampler = DDIMSampler(model)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_downloaded() -> bool:
|
def is_downloaded() -> bool:
|
||||||
@ -454,6 +265,13 @@ class LDM(InpaintModel):
|
|||||||
# image [1,3,512,512] float32
|
# image [1,3,512,512] float32
|
||||||
# mask: [1,1,512,512] float32
|
# mask: [1,1,512,512] float32
|
||||||
# masked_image: [1,3,512,512] float32
|
# masked_image: [1,3,512,512] float32
|
||||||
|
if config.ldm_sampler == LDMSampler.ddim:
|
||||||
|
sampler = DDIMSampler(self.model)
|
||||||
|
elif config.ldm_sampler == LDMSampler.plms:
|
||||||
|
sampler = PLMSSampler(self.model)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
steps = config.ldm_steps
|
steps = config.ldm_steps
|
||||||
image = norm_img(image)
|
image = norm_img(image)
|
||||||
mask = norm_img(mask)
|
mask = norm_img(mask)
|
||||||
@ -465,7 +283,6 @@ class LDM(InpaintModel):
|
|||||||
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||||
masked_image = (1 - mask) * image
|
masked_image = (1 - mask) * image
|
||||||
|
|
||||||
image = self._norm(image)
|
|
||||||
mask = self._norm(mask)
|
mask = self._norm(mask)
|
||||||
masked_image = self._norm(masked_image)
|
masked_image = self._norm(masked_image)
|
||||||
|
|
||||||
@ -476,7 +293,7 @@ class LDM(InpaintModel):
|
|||||||
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
||||||
|
|
||||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
shape = (c.shape[1] - 1,) + c.shape[2:]
|
||||||
samples_ddim = self.sampler.sample(
|
samples_ddim = sampler.sample(
|
||||||
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
225
lama_cleaner/model/plms_sampler.py
Normal file
225
lama_cleaner/model/plms_sampler.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class PLMSSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
if ddim_eta != 0:
|
||||||
|
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta, verbose=verbose)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||||
|
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
|
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self,
|
||||||
|
steps,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=False,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
|
samples = self.plms_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sampling(self, cond, shape,
|
||||||
|
x_T=None, ddim_use_original_steps=False,
|
||||||
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, ):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
if x_T is None:
|
||||||
|
img = torch.randn(shape, device=device)
|
||||||
|
else:
|
||||||
|
img = x_T
|
||||||
|
|
||||||
|
if timesteps is None:
|
||||||
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
elif timesteps is not None and not ddim_use_original_steps:
|
||||||
|
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||||
|
timesteps = self.ddim_timesteps[:subset_end]
|
||||||
|
|
||||||
|
time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||||
|
old_eps = []
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
old_eps=old_eps, t_next=ts_next)
|
||||||
|
img, pred_x0, e_t = outs
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
def get_model_output(x, t):
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
e_t = get_model_output(x, t)
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = get_model_output(x_prev, t_next)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
elif len(old_eps) >= 3:
|
||||||
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
return x_prev, pred_x0, e_t
|
@ -9,8 +9,14 @@ class HDStrategy(str, Enum):
|
|||||||
CROP = 'Crop'
|
CROP = 'Crop'
|
||||||
|
|
||||||
|
|
||||||
|
class LDMSampler(str, Enum):
|
||||||
|
ddim = 'ddim'
|
||||||
|
plms = 'plms'
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
ldm_steps: int
|
ldm_steps: int
|
||||||
|
ldm_sampler: str
|
||||||
hd_strategy: str
|
hd_strategy: str
|
||||||
hd_strategy_crop_margin: int
|
hd_strategy_crop_margin: int
|
||||||
hd_strategy_crop_trigger_size: int
|
hd_strategy_crop_trigger_size: int
|
||||||
|
@ -93,6 +93,7 @@ def process():
|
|||||||
|
|
||||||
config = Config(
|
config = Config(
|
||||||
ldm_steps=form["ldmSteps"],
|
ldm_steps=form["ldmSteps"],
|
||||||
|
ldm_sampler=form["ldmSampler"],
|
||||||
hd_strategy=form["hdStrategy"],
|
hd_strategy=form["hdStrategy"],
|
||||||
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
|
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
|
||||||
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
|
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
|
||||||
|
Loading…
Reference in New Issue
Block a user