sd make change sampler work

Qing 2022-09-22 12:38:32 +08:00
parent 047474ab84
commit e1fb0030d1
3 changed files with 46 additions and 19 deletions

View File

@@ -4,12 +4,13 @@ import PIL.Image
 import cv2
 import numpy as np
 import torch
+from diffusers import PNDMScheduler, DDIMScheduler
 from loguru import logger

 from lama_cleaner.helper import norm_img
 from lama_cleaner.model.base import InpaintModel
-from lama_cleaner.schema import Config
+from lama_cleaner.schema import Config, SDSampler

 #
@@ -43,13 +44,12 @@ class SD(InpaintModel):
     min_size = 512

     def init_model(self, device: torch.device, **kwargs):
-        # return
         from .sd_pipeline import StableDiffusionInpaintPipeline

         self.model = StableDiffusionInpaintPipeline.from_pretrained(
             self.model_id_or_path,
-            revision="fp16",
-            torch_dtype=torch.float16,
+            revision="fp16" if torch.cuda.is_available() else 'main',
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
             use_auth_token=kwargs["hf_access_token"],
         )
         # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
@@ -59,7 +59,6 @@ class SD(InpaintModel):
     @torch.cuda.amp.autocast()
     def forward(self, image, mask, config: Config):
-        # return image
         """Input image and output image have same size
         image: [H, W, C] RGB
         mask: [H, W, 1] 255 means area to repaint
@@ -76,9 +75,30 @@ class SD(InpaintModel):
         #
         # image = torch.from_numpy(image).unsqueeze(0).to(self.device)
         # mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
-        # import time
-        # time.sleep(2)
-        # return image
+        if config.sd_sampler == SDSampler.ddim:
+            scheduler = DDIMScheduler(
+                beta_start=0.00085,
+                beta_end=0.012,
+                beta_schedule="scaled_linear",
+                clip_sample=False,
+                set_alpha_to_one=False,
+            )
+        elif config.sd_sampler == SDSampler.pndm:
+            PNDM_kwargs = {
+                "tensor_format": "pt",
+                "beta_schedule": "scaled_linear",
+                "beta_start": 0.00085,
+                "beta_end": 0.012,
+                "num_train_timesteps": 1000,
+                "skip_prk_steps": True
+            }
+            scheduler = PNDMScheduler(**PNDM_kwargs)
+        else:
+            raise ValueError(config.sd_sampler)
+
+        self.model.scheduler = scheduler
+
         seed = config.sd_seed
         random.seed(seed)
         np.random.seed(seed)
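
Note: the scheduler selection above can be read as a small standalone helper. The sketch below is a restatement, not code from the repository: get_scheduler is a hypothetical name, the enum is duplicated so the snippet runs on its own, and it assumes diffusers 0.3.x (where PNDMScheduler still accepts tensor_format). The constructor arguments are the ones from the diff.

from enum import Enum

from diffusers import DDIMScheduler, PNDMScheduler


class SDSampler(str, Enum):
    ddim = 'ddim'
    pndm = 'pndm'


def get_scheduler(sampler: SDSampler):
    """Map a config value to a diffusers scheduler (hypothetical helper)."""
    if sampler == SDSampler.ddim:
        # DDIM with the Stable Diffusion beta schedule; no sample clipping.
        return DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
    if sampler == SDSampler.pndm:
        # PNDM with the same betas; skip_prk_steps=True goes straight to the
        # PLMS steps instead of the Runge-Kutta warm-up.
        return PNDMScheduler(
            tensor_format="pt",  # needed on diffusers 0.3.x, removed in later releases
            beta_schedule="scaled_linear",
            beta_start=0.00085,
            beta_end=0.012,
            num_train_timesteps=1000,
            skip_prk_steps=True,
        )
    raise ValueError(f"unknown sampler: {sampler}")

The pipeline then picks the result up exactly as in forward above, assigning it to self.model.scheduler before running inference.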

View File

@@ -14,6 +14,11 @@ class LDMSampler(str, Enum):
     plms = 'plms'


+class SDSampler(str, Enum):
+    ddim = 'ddim'
+    pndm = 'pndm'
+
+
 class Config(BaseModel):
     ldm_steps: int
     ldm_sampler: str = LDMSampler.plms
@@ -35,6 +40,6 @@ class Config(BaseModel):
     sd_strength: float = 0.75
     sd_steps: int = 50
     sd_guidance_scale: float = 7.5
-    sd_sampler: str = 'ddim'  # ddim/pndm
+    sd_sampler: str = SDSampler.ddim
     # -1 mean random seed
     sd_seed: int = 42
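
Note on the schema change: SDSampler mixes in str, so its members compare equal to the plain strings a request body carries, which is why sd_sampler can keep its str annotation while defaulting to SDSampler.ddim. A self-contained illustration (the enum is repeated here so the snippet runs by itself):

from enum import Enum


class SDSampler(str, Enum):
    ddim = 'ddim'
    pndm = 'pndm'


# A raw string coming from an HTTP form field round-trips to the enum member...
assert SDSampler('pndm') is SDSampler.pndm
# ...and members compare equal to raw strings, so callers that still pass
# 'ddim' or 'pndm' directly keep working.
assert SDSampler.ddim == 'ddim'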

View File

@@ -6,7 +6,7 @@ import pytest
 import torch

 from lama_cleaner.model_manager import ModelManager
-from lama_cleaner.schema import Config, HDStrategy, LDMSampler
+from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler

 current_dir = Path(__file__).parent.absolute().resolve()

 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -155,25 +155,27 @@ def test_fcf(strategy):
         fy=2
     )


 @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-def test_sd(strategy, capfd):
+@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm])
+def test_sd(strategy, sampler, capfd):
     def callback(step: int):
         print(f"sd_step_{step}")

-    sd_steps = 2
-    model = ModelManager(name="sd", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], callbacks=[callback])
+    sd_steps = 50
+    model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'],
+                         callbacks=[callback])
     cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
+    cfg.sd_sampler = sampler

     assert_equal(
         model,
         cfg,
-        f"sd_{strategy.capitalize()}_result.png",
+        f"sd_{strategy.capitalize()}_{sampler}_result.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-        fx=0.5,
-        fy=0.5
     )

-    captured = capfd.readouterr()
-    for i in range(sd_steps):
-        assert f'sd_step_{i}' in captured.out
+    # captured = capfd.readouterr()
+    # for i in range(sd_steps):
+    #     assert f'sd_step_{i}' in captured.out
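
Note on the test change: stacking a second parametrize marker makes pytest run the test body for the cross product of the argument lists, so test_sd now covers every (strategy, sampler) combination and writes one result image per sampler. A tiny self-contained illustration with hypothetical values (not the repository's fixtures):

import pytest


# Two stacked parametrize markers expand to the cross product of their
# argument lists: one test invocation per (strategy, sampler) pair.
@pytest.mark.parametrize("strategy", ["Original"])
@pytest.mark.parametrize("sampler", ["ddim", "pndm"])
def test_cross_product(strategy, sampler):
    assert (strategy, sampler) in {("Original", "ddim"), ("Original", "pndm")}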