fix model init

Qing 2022-09-22 22:45:24 +08:00
parent 8d65195e8a
commit 5bdc5c1526
2 changed files with 18 additions and 13 deletions

View File: lama_cleaner/model/ldm.py

@@ -11,7 +11,12 @@ from lama_cleaner.schema import Config, LDMSampler
 torch.manual_seed(42)
 import torch.nn as nn
-from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url, load_jit_model
+from lama_cleaner.helper import (
+    download_model,
+    norm_img,
+    get_cache_path_by_url,
+    load_jit_model,
+)
 from lama_cleaner.model.utils import (
     make_beta_schedule,
     timestep_embedding,
@@ -92,7 +97,7 @@ class DDPM(nn.Module):
         self.linear_start = linear_start
         self.linear_end = linear_end
         assert (
-                alphas_cumprod.shape[0] == self.num_timesteps
+            alphas_cumprod.shape[0] == self.num_timesteps
         ), "alphas have to be defined for each timestep"
         to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
@@ -118,7 +123,7 @@ class DDPM(nn.Module):
         # calculations for posterior q(x_{t-1} | x_t, x_0)
         posterior_variance = (1 - self.v_posterior) * betas * (
-                1.0 - alphas_cumprod_prev
+            1.0 - alphas_cumprod_prev
         ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
         self.register_buffer("posterior_variance", to_torch(posterior_variance))
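For reference, the expression in this hunk is the usual DDPM posterior variance with the v_posterior interpolation the inline comment alludes to; written out in the notation common to the diffusion literature (an assumption on my part, since the file itself stays in code):

```latex
% Variance of q(x_{t-1} | x_t, x_0) as computed above, with
% v = v_posterior; v = 0 recovers the standard DDPM value.
\tilde{\beta}_t = (1 - v)\,\beta_t\,
    \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} + v\,\beta_t
```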
@@ -139,17 +144,17 @@ class DDPM(nn.Module):
         )
         if self.parameterization == "eps":
-            lvlb_weights = self.betas ** 2 / (
-                    2
-                    * self.posterior_variance
-                    * to_torch(alphas)
-                    * (1 - self.alphas_cumprod)
+            lvlb_weights = self.betas**2 / (
+                2
+                * self.posterior_variance
+                * to_torch(alphas)
+                * (1 - self.alphas_cumprod)
             )
         elif self.parameterization == "x0":
             lvlb_weights = (
-                    0.5
-                    * np.sqrt(torch.Tensor(alphas_cumprod))
-                    / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+                0.5
+                * np.sqrt(torch.Tensor(alphas_cumprod))
+                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
             )
         else:
             raise NotImplementedError("mu not supported")
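The "eps" branch above builds the per-timestep weights of the variational lower bound. Written out, assuming standard DDPM notation with sigma_t^2 being the posterior_variance registered earlier:

```latex
% Weight of the t-th VLB term for the eps parameterization, matching
% betas**2 / (2 * posterior_variance * alphas * (1 - alphas_cumprod)).
w_t = \frac{\beta_t^2}{2\,\sigma_t^2\,\alpha_t\,\left(1 - \bar{\alpha}_t\right)}
```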
@@ -222,7 +227,7 @@ class LatentDiffusion(DDPM):
 class LDM(InpaintModel):
     pad_mod = 32

-    def __init__(self, device, fp16: bool = True):
+    def __init__(self, device, fp16: bool = True, **kwargs):
         self.fp16 = fp16
         super().__init__(device)
         self.device = device
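The functional change in this file is the added **kwargs on LDM.__init__. A minimal sketch of why that matters, assuming a manager that constructs whichever model the user selected from one shared argument set (the names below are hypothetical, not lama-cleaner's actual API):

```python
from typing import Dict, Type


class FakeInpaintModel:
    """Stand-in for InpaintModel subclasses such as LDM or ZITS."""

    def __init__(self, device, **kwargs):
        # Model-specific options (e.g. fp16 for LDM) arrive via **kwargs;
        # models that do not use them can simply ignore them.
        self.device = device


MODELS: Dict[str, Type[FakeInpaintModel]] = {
    "ldm": FakeInpaintModel,
    "zits": FakeInpaintModel,
}


def init_model(name: str, device: str, **kwargs) -> FakeInpaintModel:
    # One call site for every model: fp16 is passed unconditionally, and
    # constructors that do not care about it accept it via **kwargs
    # instead of raising TypeError at model init.
    return MODELS[name](device, **kwargs)


model = init_model("zits", device="cpu", fp16=True)
```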

View File: lama_cleaner/model/zits.py

@@ -206,7 +206,7 @@ class ZITS(InpaintModel):
     pad_mod = 32
     pad_to_square = True

-    def __init__(self, device):
+    def __init__(self, device, **kwargs):
         """
         Args: