add sd1.5

This commit is contained in:
Qing 2022-10-20 21:01:14 +08:00
parent d892d9166f
commit 6ccb6cd291
6 changed files with 64 additions and 424 deletions

View File

@ -7,8 +7,6 @@ import torch
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
from loguru import logger from loguru import logger
from lama_cleaner.helper import norm_img
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config, SDSampler from lama_cleaner.schema import Config, SDSampler
@ -38,12 +36,22 @@ from lama_cleaner.schema import Config, SDSampler
# mask = torch.from_numpy(mask) # mask = torch.from_numpy(mask)
# return mask # return mask
class CPUTextEncoderWrapper:
def __init__(self, text_encoder):
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
def __call__(self, x):
input_device = x.device
return [self.text_encoder(x.to(self.text_encoder.device))[0].to(input_device)]
class SD(InpaintModel): class SD(InpaintModel):
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505 pad_mod = 8 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
min_size = 512 min_size = 512
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from .sd_pipeline import StableDiffusionInpaintPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
model_kwargs = {"local_files_only": kwargs['sd_run_local']} model_kwargs = {"local_files_only": kwargs['sd_run_local']}
if kwargs['sd_disable_nsfw']: if kwargs['sd_disable_nsfw']:
@ -65,8 +73,7 @@ class SD(InpaintModel):
if kwargs['sd_cpu_textencoder']: if kwargs['sd_cpu_textencoder']:
logger.info("Run Stable Diffusion TextEncoder on CPU") logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True) self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder)
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
self.callback = kwargs.pop("callback", None) self.callback = kwargs.pop("callback", None)
@ -99,7 +106,6 @@ class SD(InpaintModel):
) )
elif config.sd_sampler == SDSampler.pndm: elif config.sd_sampler == SDSampler.pndm:
PNDM_kwargs = { PNDM_kwargs = {
"tensor_format": "pt",
"beta_schedule": "scaled_linear", "beta_schedule": "scaled_linear",
"beta_start": 0.00085, "beta_start": 0.00085,
"beta_end": 0.012, "beta_end": 0.012,
@ -124,15 +130,19 @@ class SD(InpaintModel):
k = 2 * config.sd_mask_blur + 1 k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
_kwargs = {
self.image_key: PIL.Image.fromarray(image),
}
output = self.model( output = self.model(
prompt=config.prompt, prompt=config.prompt,
init_image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
strength=config.sd_strength, strength=config.sd_strength,
num_inference_steps=config.sd_steps, num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale, guidance_scale=config.sd_guidance_scale,
output_type="np.array", output_type="np.array",
callback=self.callback, callback=self.callback,
**_kwargs
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")
@ -185,7 +195,9 @@ class SD(InpaintModel):
class SD14(SD): class SD14(SD):
model_id_or_path = "CompVis/stable-diffusion-v1-4" model_id_or_path = "CompVis/stable-diffusion-v1-4"
image_key = "init_image"
class SD15(SD): class SD15(SD):
model_id_or_path = "CompVis/stable-diffusion-v1-5" model_id_or_path = "runwayml/stable-diffusion-inpainting"
image_key = "image"

View File

@ -1,406 +0,0 @@
import inspect
from typing import List, Optional, Union, Callable
import numpy as np
import torch
import PIL
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.utils import logging, deprecate
from diffusers.configuration_utils import FrozenDict
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__)
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = 1 - mask # repaint white, keep black
mask = torch.from_numpy(mask)
return mask
class StableDiffusionInpaintPipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
num_inference_steps (`int`, *optional*, defaults to 50):
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_encoder_device = self.text_encoder.device
text_embeddings = self.text_encoder(text_input_ids.to(text_encoder_device))[0].to(self.device)
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device))[0].to(self.device)
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
# encode the init image into latents and scale the latents
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
for i, t in tqdm(enumerate(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -2,12 +2,12 @@ from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.lama import LaMa from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.mat import MAT from lama_cleaner.model.mat import MAT
from lama_cleaner.model.sd import SD14 from lama_cleaner.model.sd import SD14, SD15
from lama_cleaner.model.zits import ZITS from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.4": SD14, "cv2": OpenCV2} models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.4": SD14, "sd1.5": SD15, "cv2": OpenCV2}
class ModelManager: class ModelManager:

View File

@ -10,7 +10,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--model", "--model",
default="lama", default="lama",
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.4", "cv2"], choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.4", "sd1.5", "cv2"],
) )
parser.add_argument( parser.add_argument(
"--hf_access_token", "--hf_access_token",

View File

@ -159,10 +159,10 @@ def test_fcf(strategy):
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm]) @pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
def test_sd(strategy, sampler): def test_sd(strategy, sampler):
def callback(step: int): def callback(i, t, latents):
print(f"sd_step_{step}") print(f"sd_step_{i}")
sd_steps = 50 sd_steps = 50
model = ModelManager(name="sd1.4", model = ModelManager(name="sd1.4",
@ -197,8 +197,8 @@ def test_sd(strategy, sampler):
@pytest.mark.parametrize("disable_nsfw", [True, False]) @pytest.mark.parametrize("disable_nsfw", [True, False])
@pytest.mark.parametrize("cpu_textencoder", [True, False]) @pytest.mark.parametrize("cpu_textencoder", [True, False])
def test_sd_run_local(strategy, sampler, disable_nsfw, cpu_textencoder): def test_sd_run_local(strategy, sampler, disable_nsfw, cpu_textencoder):
def callback(step: int): def callback(i, t, latents):
print(f"sd_step_{step}") print(f"sd_step_{i}")
sd_steps = 50 sd_steps = 50
model = ModelManager( model = ModelManager(
@ -222,6 +222,40 @@ def test_sd_run_local(strategy, sampler, disable_nsfw, cpu_textencoder):
) )
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
def test_runway_sd_1_5(strategy, sampler):
def callback(i, t, latents):
print(f"sd_step_{i}")
sd_steps = 20
model = ModelManager(name="sd1.5",
device=device,
hf_access_token=None,
sd_run_local=True,
sd_disable_nsfw=True,
sd_cpu_textencoder=True,
callback=callback)
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
cfg.sd_sampler = sampler
assert_equal(
model,
cfg,
f"runway_sd_{strategy.capitalize()}_{sampler}_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
assert_equal(
model,
cfg,
f"runway_sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
) )

View File

@ -10,5 +10,5 @@ pytest
yacs yacs
markupsafe==2.0.1 markupsafe==2.0.1
scikit-image==0.19.3 scikit-image==0.19.3
diffusers==0.5.1 diffusers==0.6.0
transformers==4.21.0 transformers==4.21.0