fix model init

This commit is contained in:
Qing 2022-09-22 22:45:24 +08:00
parent 8d65195e8a
commit 5bdc5c1526
2 changed files with 18 additions and 13 deletions

View File

@ -11,7 +11,12 @@ from lama_cleaner.schema import Config, LDMSampler
torch.manual_seed(42)
import torch.nn as nn
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url, load_jit_model
from lama_cleaner.helper import (
download_model,
norm_img,
get_cache_path_by_url,
load_jit_model,
)
from lama_cleaner.model.utils import (
make_beta_schedule,
timestep_embedding,
@ -139,7 +144,7 @@ class DDPM(nn.Module):
)
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
lvlb_weights = self.betas**2 / (
2
* self.posterior_variance
* to_torch(alphas)
@ -222,7 +227,7 @@ class LatentDiffusion(DDPM):
class LDM(InpaintModel):
pad_mod = 32
def __init__(self, device, fp16: bool = True):
def __init__(self, device, fp16: bool = True, **kwargs):
self.fp16 = fp16
super().__init__(device)
self.device = device

View File

@ -206,7 +206,7 @@ class ZITS(InpaintModel):
pad_mod = 32
pad_to_square = True
def __init__(self, device):
def __init__(self, device, **kwargs):
"""
Args: