enable fp16 for ldm by default
This commit is contained in:
parent
e369a2f079
commit
e4a6c91f4a
@ -11,8 +11,13 @@ torch.manual_seed(42)
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||||
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
from lama_cleaner.model.utils import (
|
||||||
timestep_embedding
|
make_beta_schedule,
|
||||||
|
make_ddim_timesteps,
|
||||||
|
make_ddim_sampling_parameters,
|
||||||
|
noise_like,
|
||||||
|
timestep_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
LDM_ENCODE_MODEL_URL = os.environ.get(
|
LDM_ENCODE_MODEL_URL = os.environ.get(
|
||||||
"LDM_ENCODE_MODEL_URL",
|
"LDM_ENCODE_MODEL_URL",
|
||||||
@ -32,18 +37,20 @@ LDM_DIFFUSION_MODEL_URL = os.environ.get(
|
|||||||
|
|
||||||
class DDPM(nn.Module):
|
class DDPM(nn.Module):
|
||||||
# classic DDPM with Gaussian diffusion, in image space
|
# classic DDPM with Gaussian diffusion, in image space
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
device,
|
device,
|
||||||
timesteps=1000,
|
timesteps=1000,
|
||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
linear_start=0.0015,
|
linear_start=0.0015,
|
||||||
linear_end=0.0205,
|
linear_end=0.0205,
|
||||||
cosine_s=0.008,
|
cosine_s=0.008,
|
||||||
original_elbo_weight=0.,
|
original_elbo_weight=0.0,
|
||||||
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
||||||
l_simple_weight=1.,
|
l_simple_weight=1.0,
|
||||||
parameterization="eps", # all assuming fixed variance schedules
|
parameterization="eps", # all assuming fixed variance schedules
|
||||||
use_positional_encodings=False):
|
use_positional_encodings=False,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.parameterization = parameterization
|
self.parameterization = parameterization
|
||||||
@ -53,64 +60,110 @@ class DDPM(nn.Module):
|
|||||||
self.original_elbo_weight = original_elbo_weight
|
self.original_elbo_weight = original_elbo_weight
|
||||||
self.l_simple_weight = l_simple_weight
|
self.l_simple_weight = l_simple_weight
|
||||||
|
|
||||||
self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
|
self.register_schedule(
|
||||||
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
beta_schedule=beta_schedule,
|
||||||
|
timesteps=timesteps,
|
||||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
linear_start=linear_start,
|
||||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
|
||||||
betas = make_beta_schedule(self.device, beta_schedule, timesteps, linear_start=linear_start,
|
|
||||||
linear_end=linear_end,
|
linear_end=linear_end,
|
||||||
cosine_s=cosine_s)
|
cosine_s=cosine_s,
|
||||||
alphas = 1. - betas
|
)
|
||||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
|
||||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
|
||||||
|
|
||||||
timesteps, = betas.shape
|
def register_schedule(
|
||||||
|
self,
|
||||||
|
given_betas=None,
|
||||||
|
beta_schedule="linear",
|
||||||
|
timesteps=1000,
|
||||||
|
linear_start=1e-4,
|
||||||
|
linear_end=2e-2,
|
||||||
|
cosine_s=8e-3,
|
||||||
|
):
|
||||||
|
betas = make_beta_schedule(
|
||||||
|
self.device,
|
||||||
|
beta_schedule,
|
||||||
|
timesteps,
|
||||||
|
linear_start=linear_start,
|
||||||
|
linear_end=linear_end,
|
||||||
|
cosine_s=cosine_s,
|
||||||
|
)
|
||||||
|
alphas = 1.0 - betas
|
||||||
|
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||||
|
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
||||||
|
|
||||||
|
(timesteps,) = betas.shape
|
||||||
self.num_timesteps = int(timesteps)
|
self.num_timesteps = int(timesteps)
|
||||||
self.linear_start = linear_start
|
self.linear_start = linear_start
|
||||||
self.linear_end = linear_end
|
self.linear_end = linear_end
|
||||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
assert (
|
||||||
|
alphas_cumprod.shape[0] == self.num_timesteps
|
||||||
|
), "alphas have to be defined for each timestep"
|
||||||
|
|
||||||
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||||
|
|
||||||
self.register_buffer('betas', to_torch(betas))
|
self.register_buffer("betas", to_torch(betas))
|
||||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
||||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
self.register_buffer(
|
||||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
|
||||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
)
|
||||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
self.register_buffer(
|
||||||
|
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
|
||||||
|
)
|
||||||
|
|
||||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||||
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
posterior_variance = (1 - self.v_posterior) * betas * (
|
||||||
1. - alphas_cumprod) + self.v_posterior * betas
|
1.0 - alphas_cumprod_prev
|
||||||
|
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
||||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||||
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
||||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||||
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
self.register_buffer(
|
||||||
self.register_buffer('posterior_mean_coef1', to_torch(
|
"posterior_log_variance_clipped",
|
||||||
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
|
||||||
self.register_buffer('posterior_mean_coef2', to_torch(
|
)
|
||||||
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
self.register_buffer(
|
||||||
|
"posterior_mean_coef1",
|
||||||
|
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"posterior_mean_coef2",
|
||||||
|
to_torch(
|
||||||
|
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if self.parameterization == "eps":
|
if self.parameterization == "eps":
|
||||||
lvlb_weights = self.betas ** 2 / (
|
lvlb_weights = self.betas ** 2 / (
|
||||||
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
2
|
||||||
|
* self.posterior_variance
|
||||||
|
* to_torch(alphas)
|
||||||
|
* (1 - self.alphas_cumprod)
|
||||||
|
)
|
||||||
elif self.parameterization == "x0":
|
elif self.parameterization == "x0":
|
||||||
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
lvlb_weights = (
|
||||||
|
0.5
|
||||||
|
* np.sqrt(torch.Tensor(alphas_cumprod))
|
||||||
|
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("mu not supported")
|
raise NotImplementedError("mu not supported")
|
||||||
# TODO how to choose this term
|
# TODO how to choose this term
|
||||||
lvlb_weights[0] = lvlb_weights[1]
|
lvlb_weights[0] = lvlb_weights[1]
|
||||||
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
|
||||||
assert not torch.isnan(self.lvlb_weights).all()
|
assert not torch.isnan(self.lvlb_weights).all()
|
||||||
|
|
||||||
|
|
||||||
class LatentDiffusion(DDPM):
|
class LatentDiffusion(DDPM):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
diffusion_model,
|
diffusion_model,
|
||||||
device,
|
device,
|
||||||
cond_stage_key="image",
|
cond_stage_key="image",
|
||||||
@ -118,7 +171,9 @@ class LatentDiffusion(DDPM):
|
|||||||
concat_mode=True,
|
concat_mode=True,
|
||||||
scale_factor=1.0,
|
scale_factor=1.0,
|
||||||
scale_by_std=False,
|
scale_by_std=False,
|
||||||
*args, **kwargs):
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
self.num_timesteps_cond = 1
|
self.num_timesteps_cond = 1
|
||||||
self.scale_by_std = scale_by_std
|
self.scale_by_std = scale_by_std
|
||||||
super().__init__(device, *args, **kwargs)
|
super().__init__(device, *args, **kwargs)
|
||||||
@ -129,15 +184,31 @@ class LatentDiffusion(DDPM):
|
|||||||
self.num_downs = 2
|
self.num_downs = 2
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
def make_cond_schedule(self, ):
|
def make_cond_schedule(
|
||||||
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
self,
|
||||||
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
):
|
||||||
|
self.cond_ids = torch.full(
|
||||||
|
size=(self.num_timesteps,),
|
||||||
|
fill_value=self.num_timesteps - 1,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
ids = torch.round(
|
||||||
|
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
|
||||||
|
).long()
|
||||||
self.cond_ids[: self.num_timesteps_cond] = ids
|
self.cond_ids[: self.num_timesteps_cond] = ids
|
||||||
|
|
||||||
def register_schedule(self,
|
def register_schedule(
|
||||||
given_betas=None, beta_schedule="linear", timesteps=1000,
|
self,
|
||||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
given_betas=None,
|
||||||
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
|
beta_schedule="linear",
|
||||||
|
timesteps=1000,
|
||||||
|
linear_start=1e-4,
|
||||||
|
linear_end=2e-2,
|
||||||
|
cosine_s=8e-3,
|
||||||
|
):
|
||||||
|
super().register_schedule(
|
||||||
|
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
||||||
|
)
|
||||||
|
|
||||||
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
||||||
if self.shorten_cond_schedule:
|
if self.shorten_cond_schedule:
|
||||||
@ -160,37 +231,66 @@ class DDIMSampler(object):
|
|||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
def make_schedule(
|
||||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||||
|
):
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
|
ddim_discr_method=ddim_discretize,
|
||||||
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
# array([1])
|
# array([1])
|
||||||
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
||||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
assert (
|
||||||
|
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||||
|
), "alphas have to be defined for each timestep"
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
self.register_buffer('betas', to_torch(self.model.betas))
|
self.register_buffer("betas", to_torch(self.model.betas))
|
||||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
self.register_buffer(
|
||||||
|
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
self.register_buffer(
|
||||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
)
|
||||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
self.register_buffer(
|
||||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
"sqrt_one_minus_alphas_cumprod",
|
||||||
|
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recipm1_alphas_cumprod",
|
||||||
|
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||||
|
)
|
||||||
|
|
||||||
# ddim sampling parameters
|
# ddim sampling parameters
|
||||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||||
|
alphacums=alphas_cumprod.cpu(),
|
||||||
ddim_timesteps=self.ddim_timesteps,
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
eta=ddim_eta, verbose=verbose)
|
eta=ddim_eta,
|
||||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
verbose=verbose,
|
||||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
)
|
||||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||||
|
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||||
|
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
(1 - self.alphas_cumprod_prev)
|
||||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
/ (1 - self.alphas_cumprod)
|
||||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample(self, steps, conditioning, batch_size, shape):
|
def sample(self, steps, conditioning, batch_size, shape):
|
||||||
@ -200,65 +300,108 @@ class DDIMSampler(object):
|
|||||||
size = (batch_size, C, H, W)
|
size = (batch_size, C, H, W)
|
||||||
|
|
||||||
# samples: 1,3,128,128
|
# samples: 1,3,128,128
|
||||||
return self.ddim_sampling(conditioning,
|
return self.ddim_sampling(
|
||||||
|
conditioning,
|
||||||
size,
|
size,
|
||||||
quantize_denoised=False,
|
quantize_denoised=False,
|
||||||
ddim_use_original_steps=False,
|
ddim_use_original_steps=False,
|
||||||
noise_dropout=0,
|
noise_dropout=0,
|
||||||
temperature=1.,
|
temperature=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def ddim_sampling(self, cond, shape,
|
def ddim_sampling(
|
||||||
|
self,
|
||||||
|
cond,
|
||||||
|
shape,
|
||||||
ddim_use_original_steps=False,
|
ddim_use_original_steps=False,
|
||||||
quantize_denoised=False,
|
quantize_denoised=False,
|
||||||
temperature=1., noise_dropout=0.):
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
):
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
img = torch.randn(shape, device=device) # 用了
|
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
||||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps # 用了
|
timesteps = (
|
||||||
|
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
)
|
||||||
|
|
||||||
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
time_range = (
|
||||||
|
reversed(range(0, timesteps))
|
||||||
|
if ddim_use_original_steps
|
||||||
|
else np.flip(timesteps)
|
||||||
|
)
|
||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||||
|
|
||||||
for i, step in enumerate(iterator):
|
for i, step in enumerate(iterator):
|
||||||
index = total_steps - i - 1
|
index = total_steps - i - 1
|
||||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
|
||||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
outs = self.p_sample_ddim(
|
||||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
img,
|
||||||
noise_dropout=noise_dropout)
|
cond,
|
||||||
|
ts,
|
||||||
|
index=index,
|
||||||
|
use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised,
|
||||||
|
temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
)
|
||||||
img, _ = outs
|
img, _ = outs
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
def p_sample_ddim(
|
||||||
temperature=1., noise_dropout=0.):
|
self,
|
||||||
|
x,
|
||||||
|
c,
|
||||||
|
t,
|
||||||
|
index,
|
||||||
|
repeat_noise=False,
|
||||||
|
use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
e_t = self.model.apply_model(x, t, c)
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
alphas_prev = (
|
||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
self.model.alphas_cumprod_prev
|
||||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
if use_original_steps
|
||||||
|
else self.ddim_alphas_prev
|
||||||
|
)
|
||||||
|
sqrt_one_minus_alphas = (
|
||||||
|
self.model.sqrt_one_minus_alphas_cumprod
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sqrt_one_minus_alphas
|
||||||
|
)
|
||||||
|
sigmas = (
|
||||||
|
self.model.ddim_sigmas_for_original_num_steps
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sigmas
|
||||||
|
)
|
||||||
# select parameters corresponding to the currently considered timestep
|
# select parameters corresponding to the currently considered timestep
|
||||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
sqrt_one_minus_at = torch.full(
|
||||||
|
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||||
|
)
|
||||||
|
|
||||||
# current prediction for x_0
|
# current prediction for x_0
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
if quantize_denoised: # 没用
|
if quantize_denoised: # 没用
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
# direction pointing to x_t
|
# direction pointing to x_t
|
||||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
if noise_dropout > 0.: # 没用
|
if noise_dropout > 0.0: # 没用
|
||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
return x_prev, pred_x0
|
return x_prev, pred_x0
|
||||||
@ -275,7 +418,8 @@ def load_jit_model(url, device):
|
|||||||
class LDM(InpaintModel):
|
class LDM(InpaintModel):
|
||||||
pad_mod = 32
|
pad_mod = 32
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, device, fp16: bool = True):
|
||||||
|
self.fp16 = fp16
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@ -283,6 +427,10 @@ class LDM(InpaintModel):
|
|||||||
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
||||||
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
||||||
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
||||||
|
if self.fp16 and "cuda" in str(device):
|
||||||
|
self.diffusion_model = self.diffusion_model.half()
|
||||||
|
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
||||||
|
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
||||||
|
|
||||||
model = LatentDiffusion(self.diffusion_model, device)
|
model = LatentDiffusion(self.diffusion_model, device)
|
||||||
self.sampler = DDIMSampler(model)
|
self.sampler = DDIMSampler(model)
|
||||||
@ -296,6 +444,7 @@ class LDM(InpaintModel):
|
|||||||
]
|
]
|
||||||
return all([os.path.exists(it) for it in model_paths])
|
return all([os.path.exists(it) for it in model_paths])
|
||||||
|
|
||||||
|
@torch.cuda.amp.autocast()
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""
|
"""
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
@ -321,16 +470,20 @@ class LDM(InpaintModel):
|
|||||||
masked_image = self._norm(masked_image)
|
masked_image = self._norm(masked_image)
|
||||||
|
|
||||||
c = self.cond_stage_model_encode(masked_image)
|
c = self.cond_stage_model_encode(masked_image)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
|
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
|
||||||
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
||||||
|
|
||||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
shape = (c.shape[1] - 1,) + c.shape[2:]
|
||||||
samples_ddim = self.sampler.sample(steps=steps,
|
samples_ddim = self.sampler.sample(
|
||||||
conditioning=c,
|
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
||||||
batch_size=c.shape[0],
|
)
|
||||||
shape=shape)
|
torch.cuda.empty_cache()
|
||||||
x_samples_ddim = self.cond_stage_model_decode(samples_ddim) # samples_ddim: 1, 3, 128, 128 float32
|
x_samples_ddim = self.cond_stage_model_decode(
|
||||||
|
samples_ddim
|
||||||
|
) # samples_ddim: 1, 3, 128, 128 float32
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
Loading…
Reference in New Issue
Block a user