fix model init
This commit is contained in:
parent
8d65195e8a
commit
5bdc5c1526
@ -11,7 +11,12 @@ from lama_cleaner.schema import Config, LDMSampler
|
|||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url, load_jit_model
|
from lama_cleaner.helper import (
|
||||||
|
download_model,
|
||||||
|
norm_img,
|
||||||
|
get_cache_path_by_url,
|
||||||
|
load_jit_model,
|
||||||
|
)
|
||||||
from lama_cleaner.model.utils import (
|
from lama_cleaner.model.utils import (
|
||||||
make_beta_schedule,
|
make_beta_schedule,
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
@ -92,7 +97,7 @@ class DDPM(nn.Module):
|
|||||||
self.linear_start = linear_start
|
self.linear_start = linear_start
|
||||||
self.linear_end = linear_end
|
self.linear_end = linear_end
|
||||||
assert (
|
assert (
|
||||||
alphas_cumprod.shape[0] == self.num_timesteps
|
alphas_cumprod.shape[0] == self.num_timesteps
|
||||||
), "alphas have to be defined for each timestep"
|
), "alphas have to be defined for each timestep"
|
||||||
|
|
||||||
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||||
@ -118,7 +123,7 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||||
posterior_variance = (1 - self.v_posterior) * betas * (
|
posterior_variance = (1 - self.v_posterior) * betas * (
|
||||||
1.0 - alphas_cumprod_prev
|
1.0 - alphas_cumprod_prev
|
||||||
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
||||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||||
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
||||||
@ -139,17 +144,17 @@ class DDPM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.parameterization == "eps":
|
if self.parameterization == "eps":
|
||||||
lvlb_weights = self.betas ** 2 / (
|
lvlb_weights = self.betas**2 / (
|
||||||
2
|
2
|
||||||
* self.posterior_variance
|
* self.posterior_variance
|
||||||
* to_torch(alphas)
|
* to_torch(alphas)
|
||||||
* (1 - self.alphas_cumprod)
|
* (1 - self.alphas_cumprod)
|
||||||
)
|
)
|
||||||
elif self.parameterization == "x0":
|
elif self.parameterization == "x0":
|
||||||
lvlb_weights = (
|
lvlb_weights = (
|
||||||
0.5
|
0.5
|
||||||
* np.sqrt(torch.Tensor(alphas_cumprod))
|
* np.sqrt(torch.Tensor(alphas_cumprod))
|
||||||
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("mu not supported")
|
raise NotImplementedError("mu not supported")
|
||||||
@ -222,7 +227,7 @@ class LatentDiffusion(DDPM):
|
|||||||
class LDM(InpaintModel):
|
class LDM(InpaintModel):
|
||||||
pad_mod = 32
|
pad_mod = 32
|
||||||
|
|
||||||
def __init__(self, device, fp16: bool = True):
|
def __init__(self, device, fp16: bool = True, **kwargs):
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -206,7 +206,7 @@ class ZITS(InpaintModel):
|
|||||||
pad_mod = 32
|
pad_mod = 32
|
||||||
pad_to_square = True
|
pad_to_square = True
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, device, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
Loading…
Reference in New Issue
Block a user