From e4a6c91f4a44588574d866c2c93933612b6f5c37 Mon Sep 17 00:00:00 2001
From: Qing
Date: Fri, 27 May 2022 23:00:07 +0800
Subject: [PATCH] enable fp16 for ldm by default

---
 lama_cleaner/model/ldm.py | 385 ++++++++++++++++++++++++++------------
 1 file changed, 269 insertions(+), 116 deletions(-)

diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py
index fc7a0e0..2fea7c0 100644
--- a/lama_cleaner/model/ldm.py
+++ b/lama_cleaner/model/ldm.py
@@ -11,8 +11,13 @@ torch.manual_seed(42)
 import torch.nn as nn
 from tqdm import tqdm
 from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
-from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
-    timestep_embedding
+from lama_cleaner.model.utils import (
+    make_beta_schedule,
+    make_ddim_timesteps,
+    make_ddim_sampling_parameters,
+    noise_like,
+    timestep_embedding,
+)
 
 LDM_ENCODE_MODEL_URL = os.environ.get(
     "LDM_ENCODE_MODEL_URL",
@@ -32,18 +37,20 @@ LDM_DIFFUSION_MODEL_URL = os.environ.get(
 
 class DDPM(nn.Module):
     # classic DDPM with Gaussian diffusion, in image space
-    def __init__(self,
-                 device,
-                 timesteps=1000,
-                 beta_schedule="linear",
-                 linear_start=0.0015,
-                 linear_end=0.0205,
-                 cosine_s=0.008,
-                 original_elbo_weight=0.,
-                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
-                 l_simple_weight=1.,
-                 parameterization="eps",  # all assuming fixed variance schedules
-                 use_positional_encodings=False):
+    def __init__(
+        self,
+        device,
+        timesteps=1000,
+        beta_schedule="linear",
+        linear_start=0.0015,
+        linear_end=0.0205,
+        cosine_s=0.008,
+        original_elbo_weight=0.0,
+        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+        l_simple_weight=1.0,
+        parameterization="eps",  # all assuming fixed variance schedules
+        use_positional_encodings=False,
+    ):
         super().__init__()
         self.device = device
         self.parameterization = parameterization
@@ -53,72 +60,120 @@ class DDPM(nn.Module):
         self.original_elbo_weight = original_elbo_weight
         self.l_simple_weight = l_simple_weight
 
-        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
-                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+        self.register_schedule(
+            beta_schedule=beta_schedule,
+            timesteps=timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s,
+        )
 
-    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
-                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
-        betas = make_beta_schedule(self.device, beta_schedule, timesteps, linear_start=linear_start,
-                                   linear_end=linear_end,
-                                   cosine_s=cosine_s)
-        alphas = 1. - betas
+    def register_schedule(
+        self,
+        given_betas=None,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
+        betas = make_beta_schedule(
+            self.device,
+            beta_schedule,
+            timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s,
+        )
+        alphas = 1.0 - betas
         alphas_cumprod = np.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
 
-        timesteps, = betas.shape
+        (timesteps,) = betas.shape
         self.num_timesteps = int(timesteps)
         self.linear_start = linear_start
         self.linear_end = linear_end
-        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+        assert (
+            alphas_cumprod.shape[0] == self.num_timesteps
+        ), "alphas have to be defined for each timestep"
 
         to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
 
-        self.register_buffer('betas', to_torch(betas))
-        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+        self.register_buffer("betas", to_torch(betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
 
         # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
-        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
-        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+        )
 
         # calculations for posterior q(x_{t-1} | x_t, x_0)
-        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
-                1. - alphas_cumprod) + self.v_posterior * betas
+        posterior_variance = (1 - self.v_posterior) * betas * (
+            1.0 - alphas_cumprod_prev
+        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
-        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        self.register_buffer("posterior_variance", to_torch(posterior_variance))
         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
-        self.register_buffer('posterior_mean_coef1', to_torch(
-            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
-        self.register_buffer('posterior_mean_coef2', to_torch(
-            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+        self.register_buffer(
+            "posterior_log_variance_clipped",
+            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+        )
+        self.register_buffer(
+            "posterior_mean_coef1",
+            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+        )
+        self.register_buffer(
+            "posterior_mean_coef2",
+            to_torch(
+                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+            ),
+        )
 
         if self.parameterization == "eps":
             lvlb_weights = self.betas ** 2 / (
-                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+                2
+                * self.posterior_variance
+                * to_torch(alphas)
+                * (1 - self.alphas_cumprod)
+            )
         elif self.parameterization == "x0":
-            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+            lvlb_weights = (
+                0.5
+                * np.sqrt(torch.Tensor(alphas_cumprod))
+                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+            )
         else:
             raise NotImplementedError("mu not supported")
         # TODO how to choose this term
         lvlb_weights[0] = lvlb_weights[1]
-        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
         assert not torch.isnan(self.lvlb_weights).all()
 
 
 class LatentDiffusion(DDPM):
-    def __init__(self,
-                 diffusion_model,
-                 device,
-                 cond_stage_key="image",
-                 cond_stage_trainable=False,
-                 concat_mode=True,
-                 scale_factor=1.0,
-                 scale_by_std=False,
-                 *args, **kwargs):
+    def __init__(
+        self,
+        diffusion_model,
+        device,
+        cond_stage_key="image",
+        cond_stage_trainable=False,
+        concat_mode=True,
+        scale_factor=1.0,
+        scale_by_std=False,
+        *args,
+        **kwargs,
+    ):
         self.num_timesteps_cond = 1
         self.scale_by_std = scale_by_std
         super().__init__(device, *args, **kwargs)
@@ -129,15 +184,31 @@ class LatentDiffusion(DDPM):
         self.num_downs = 2
         self.scale_factor = scale_factor
 
-    def make_cond_schedule(self, ):
-        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
-        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
-        self.cond_ids[:self.num_timesteps_cond] = ids
+    def make_cond_schedule(
+        self,
+    ):
+        self.cond_ids = torch.full(
+            size=(self.num_timesteps,),
+            fill_value=self.num_timesteps - 1,
+            dtype=torch.long,
+        )
+        ids = torch.round(
+            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
+        ).long()
+        self.cond_ids[: self.num_timesteps_cond] = ids
 
-    def register_schedule(self,
-                          given_betas=None, beta_schedule="linear", timesteps=1000,
-                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
-        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+    def register_schedule(
+        self,
+        given_betas=None,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
+        super().register_schedule(
+            given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
+        )
 
         self.shorten_cond_schedule = self.num_timesteps_cond > 1
         if self.shorten_cond_schedule:
@@ -160,37 +231,66 @@ class DDIMSampler(object):
 
     def register_buffer(self, name, attr):
         setattr(self, name, attr)
 
-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
-        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
-                                                  # array([1])
-                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+    def make_schedule(
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+    ):
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            # array([1])
+            num_ddpm_timesteps=self.ddpm_num_timesteps,
+            verbose=verbose,
+        )
         alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
-        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        assert (
+            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+        ), "alphas have to be defined for each timestep"
         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
 
-        self.register_buffer('betas', to_torch(self.model.betas))
-        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+        self.register_buffer("betas", to_torch(self.model.betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer(
+            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+        )
 
         # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+        self.register_buffer(
+            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod",
+            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod",
+            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+        )
 
         # ddim sampling parameters
-        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
-                                                                                   ddim_timesteps=self.ddim_timesteps,
-                                                                                   eta=ddim_eta, verbose=verbose)
-        self.register_buffer('ddim_sigmas', ddim_sigmas)
-        self.register_buffer('ddim_alphas', ddim_alphas)
-        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
-        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+            alphacums=alphas_cumprod.cpu(),
+            ddim_timesteps=self.ddim_timesteps,
+            eta=ddim_eta,
+            verbose=verbose,
+        )
+        self.register_buffer("ddim_sigmas", ddim_sigmas)
+        self.register_buffer("ddim_alphas", ddim_alphas)
+        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
-                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
-        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+            (1 - self.alphas_cumprod_prev)
+            / (1 - self.alphas_cumprod)
+            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+        )
+        self.register_buffer(
+            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+        )
 
     @torch.no_grad()
     def sample(self, steps, conditioning, batch_size, shape):
@@ -200,65 +300,108 @@ class DDIMSampler(object):
         size = (batch_size, C, H, W)
 
         # samples: 1,3,128,128
-        return self.ddim_sampling(conditioning,
-                                  size,
-                                  quantize_denoised=False,
-                                  ddim_use_original_steps=False,
-                                  noise_dropout=0,
-                                  temperature=1.,
-                                  )
+        return self.ddim_sampling(
+            conditioning,
+            size,
+            quantize_denoised=False,
+            ddim_use_original_steps=False,
+            noise_dropout=0,
+            temperature=1.0,
+        )
 
     @torch.no_grad()
-    def ddim_sampling(self, cond, shape,
-                      ddim_use_original_steps=False,
-                      quantize_denoised=False,
-                      temperature=1., noise_dropout=0.):
+    def ddim_sampling(
+        self,
+        cond,
+        shape,
+        ddim_use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+    ):
         device = self.model.betas.device
         b = shape[0]
-        img = torch.randn(shape, device=device)  # used
-        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps  # used
+        img = torch.randn(shape, device=device, dtype=cond.dtype)
+        timesteps = (
+            self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        )
 
-        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        time_range = (
+            reversed(range(0, timesteps))
+            if ddim_use_original_steps
+            else np.flip(timesteps)
+        )
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
         logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
 
-        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
 
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full((b,), step, device=device, dtype=torch.long)
 
-            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
-                                      quantize_denoised=quantize_denoised, temperature=temperature,
-                                      noise_dropout=noise_dropout)
+            outs = self.p_sample_ddim(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+            )
             img, _ = outs
 
         return img
 
     @torch.no_grad()
-    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-                      temperature=1., noise_dropout=0.):
+    def p_sample_ddim(
+        self,
+        x,
+        c,
+        t,
+        index,
+        repeat_noise=False,
+        use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+    ):
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
-        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
-        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
-        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        alphas_prev = (
+            self.model.alphas_cumprod_prev
+            if use_original_steps
+            else self.ddim_alphas_prev
+        )
+        sqrt_one_minus_alphas = (
+            self.model.sqrt_one_minus_alphas_cumprod
+            if use_original_steps
+            else self.ddim_sqrt_one_minus_alphas
+        )
+        sigmas = (
+            self.model.ddim_sigmas_for_original_num_steps
+            if use_original_steps
+            else self.ddim_sigmas
+        )
 
         # select parameters corresponding to the currently considered timestep
         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
-        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+        sqrt_one_minus_at = torch.full(
+            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+        )
 
         # current prediction for x_0
         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         if quantize_denoised:  # not used
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
         # direction pointing to x_t
-        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+        dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
-        if noise_dropout > 0.:  # not used
+        if noise_dropout > 0.0:  # not used
             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
         return x_prev, pred_x0
@@ -275,7 +418,8 @@ def load_jit_model(url, device):
 
 class LDM(InpaintModel):
     pad_mod = 32
 
-    def __init__(self, device):
+    def __init__(self, device, fp16: bool = True):
+        self.fp16 = fp16
         super().__init__(device)
         self.device = device
@@ -283,6 +427,10 @@ class LDM(InpaintModel):
         self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
         self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
         self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
+        if self.fp16 and "cuda" in str(device):
+            self.diffusion_model = self.diffusion_model.half()
+            self.cond_stage_model_decode = self.cond_stage_model_decode.half()
+            self.cond_stage_model_encode = self.cond_stage_model_encode.half()
 
         model = LatentDiffusion(self.diffusion_model, device)
         self.sampler = DDIMSampler(model)
@@ -296,6 +444,7 @@ class LDM(InpaintModel):
         ]
         return all([os.path.exists(it) for it in model_paths])
 
+    @torch.cuda.amp.autocast()
     def forward(self, image, mask, config: Config):
         """
         image: [H, W, C] RGB
@@ -321,16 +470,20 @@ class LDM(InpaintModel):
         masked_image = self._norm(masked_image)
 
         c = self.cond_stage_model_encode(masked_image)
+        torch.cuda.empty_cache()
 
         cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:])  # 1,1,128,128
         c = torch.cat((c, cc), dim=1)  # 1,4,128,128
 
         shape = (c.shape[1] - 1,) + c.shape[2:]
-        samples_ddim = self.sampler.sample(steps=steps,
-                                           conditioning=c,
-                                           batch_size=c.shape[0],
-                                           shape=shape)
-        x_samples_ddim = self.cond_stage_model_decode(samples_ddim)  # samples_ddim: 1, 3, 128, 128 float32
+        samples_ddim = self.sampler.sample(
+            steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
+        )
+        torch.cuda.empty_cache()
+        x_samples_ddim = self.cond_stage_model_decode(
+            samples_ddim
+        )  # samples_ddim: 1, 3, 128, 128 float32
+        torch.cuda.empty_cache()
 
         # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
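
Below is a minimal, self-contained sketch of the fp16 pattern this patch applies: weights are converted to float16 only when the device is CUDA (the same gate as the "cuda" in str(device) check in LDM.__init__), and inference runs under torch.cuda.amp.autocast, the same mechanism as the decorator added to LDM.forward, so float32 inputs can mix with half-precision weights. TinyNet here is a hypothetical stand-in for the TorchScript models loaded by load_jit_model.

    import torch

    class TinyNet(torch.nn.Module):
        # Hypothetical stand-in for the jit-loaded diffusion/encode/decode models.
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fp16 = True

    model = TinyNet().to(device).eval()
    if fp16 and "cuda" in str(device):
        # Same gate as the patch: halve weights only on CUDA, where fast fp16
        # kernels exist; CPU inference stays in float32.
        model = model.half()

    # autocast casts the float32 input down to float16 for the conv, so the
    # half-precision weights and full-precision input do not collide on dtype.
    x = torch.randn(1, 3, 64, 64, device=device)
    with torch.no_grad(), torch.cuda.amp.autocast(enabled="cuda" in str(device)):
        y = model(x)
    print(y.dtype)  # torch.float16 on CUDA, torch.float32 on CPU

Without autocast, calling a .half() model on a float32 input raises a dtype mismatch, which is why the patch pairs the .half() conversion in LDM.__init__ with the @torch.cuda.amp.autocast() decorator on forward; the torch.cuda.empty_cache() calls between encode, sample, and decode then release intermediate allocations to keep peak VRAM low.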