fix model init
This commit is contained in:
parent
8d65195e8a
commit
5bdc5c1526
@ -11,7 +11,12 @@ from lama_cleaner.schema import Config, LDMSampler
|
||||
|
||||
torch.manual_seed(42)
|
||||
import torch.nn as nn
|
||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url, load_jit_model
|
||||
from lama_cleaner.helper import (
|
||||
download_model,
|
||||
norm_img,
|
||||
get_cache_path_by_url,
|
||||
load_jit_model,
|
||||
)
|
||||
from lama_cleaner.model.utils import (
|
||||
make_beta_schedule,
|
||||
timestep_embedding,
|
||||
@ -139,7 +144,7 @@ class DDPM(nn.Module):
|
||||
)
|
||||
|
||||
if self.parameterization == "eps":
|
||||
lvlb_weights = self.betas ** 2 / (
|
||||
lvlb_weights = self.betas**2 / (
|
||||
2
|
||||
* self.posterior_variance
|
||||
* to_torch(alphas)
|
||||
@ -222,7 +227,7 @@ class LatentDiffusion(DDPM):
|
||||
class LDM(InpaintModel):
|
||||
pad_mod = 32
|
||||
|
||||
def __init__(self, device, fp16: bool = True):
|
||||
def __init__(self, device, fp16: bool = True, **kwargs):
|
||||
self.fp16 = fp16
|
||||
super().__init__(device)
|
||||
self.device = device
|
||||
|
@ -206,7 +206,7 @@ class ZITS(InpaintModel):
|
||||
pad_mod = 32
|
||||
pad_to_square = True
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device, **kwargs):
|
||||
"""
|
||||
|
||||
Args:
|
||||
|
Loading…
Reference in New Issue
Block a user