diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py
index a45a7e9..27f6d5a 100644
--- a/lama_cleaner/model/ldm.py
+++ b/lama_cleaner/model/ldm.py
@@ -11,7 +11,12 @@ from lama_cleaner.schema import Config, LDMSampler
 
 torch.manual_seed(42)
 import torch.nn as nn
-from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url, load_jit_model
+from lama_cleaner.helper import (
+    download_model,
+    norm_img,
+    get_cache_path_by_url,
+    load_jit_model,
+)
 from lama_cleaner.model.utils import (
     make_beta_schedule,
     timestep_embedding,
@@ -92,7 +97,7 @@ class DDPM(nn.Module):
         self.linear_start = linear_start
         self.linear_end = linear_end
         assert (
-                alphas_cumprod.shape[0] == self.num_timesteps
+            alphas_cumprod.shape[0] == self.num_timesteps
         ), "alphas have to be defined for each timestep"
 
         to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
@@ -118,7 +123,7 @@ class DDPM(nn.Module):
 
         # calculations for posterior q(x_{t-1} | x_t, x_0)
         posterior_variance = (1 - self.v_posterior) * betas * (
-                1.0 - alphas_cumprod_prev
+            1.0 - alphas_cumprod_prev
         ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
         self.register_buffer("posterior_variance", to_torch(posterior_variance))
@@ -139,17 +144,17 @@ class DDPM(nn.Module):
         )
 
         if self.parameterization == "eps":
-            lvlb_weights = self.betas ** 2 / (
-                    2
-                    * self.posterior_variance
-                    * to_torch(alphas)
-                    * (1 - self.alphas_cumprod)
+            lvlb_weights = self.betas**2 / (
+                2
+                * self.posterior_variance
+                * to_torch(alphas)
+                * (1 - self.alphas_cumprod)
             )
         elif self.parameterization == "x0":
             lvlb_weights = (
-                    0.5
-                    * np.sqrt(torch.Tensor(alphas_cumprod))
-                    / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+                0.5
+                * np.sqrt(torch.Tensor(alphas_cumprod))
+                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
             )
         else:
             raise NotImplementedError("mu not supported")
@@ -222,7 +227,7 @@ class LatentDiffusion(DDPM):
 class LDM(InpaintModel):
     pad_mod = 32
 
-    def __init__(self, device, fp16: bool = True):
+    def __init__(self, device, fp16: bool = True, **kwargs):
         self.fp16 = fp16
         super().__init__(device)
         self.device = device
diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py
index db59c63..597839c 100644
--- a/lama_cleaner/model/zits.py
+++ b/lama_cleaner/model/zits.py
@@ -206,7 +206,7 @@ class ZITS(InpaintModel):
     pad_mod = 32
     pad_to_square = True
 
-    def __init__(self, device):
+    def __init__(self, device, **kwargs):
         """
         Args: