update

commit 141936a937
parent f27fc51e34
@@ -84,8 +84,13 @@ SD2_CONTROLNET_CHOICES = [
 DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0"
 SDXL_CONTROLNET_CHOICES = [
     "thibaud/controlnet-openpose-sdxl-1.0",
+    "destitech/controlnet-inpaint-dreamer-sdxl"
     "diffusers/controlnet-canny-sdxl-1.0",
+    "diffusers/controlnet-canny-sdxl-1.0-mid",
+    "diffusers/controlnet-canny-sdxl-1.0-small"
     "diffusers/controlnet-depth-sdxl-1.0",
+    "diffusers/controlnet-depth-sdxl-1.0-mid",
+    "diffusers/controlnet-depth-sdxl-1.0-small",
 ]

 SD_LOCAL_MODEL_HELP = """
@@ -5,7 +5,7 @@ from typing import List
 from loguru import logger
 from pathlib import Path

-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
+from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
 from lama_cleaner.schema import (
     ModelInfo,
     ModelType,
@@ -117,9 +117,7 @@ def scan_models() -> List[ModelInfo]:

     available_models = []
     available_models.extend(scan_inpaint_models())
-    available_models.extend(
-        scan_single_file_diffusion_models(os.environ["XDG_CACHE_HOME"])
-    )
+    available_models.extend(scan_single_file_diffusion_models(DEFAULT_MODEL_DIR))

     cache_dir = Path(DIFFUSERS_CACHE)
     diffusers_model_names = []
@@ -279,15 +279,12 @@ class DiffusionInpaintModel(InpaintModel):
         """
         # boxes = boxes_from_mask(mask)
         if config.use_croper:
-            if config.croper_is_outpainting:
-                inpaint_result = self._do_outpainting(image, config)
-            else:
-                crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(
-                    image, mask, config
-                )
+            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
             crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
             inpaint_result = image[:, :, ::-1]
             inpaint_result[t:b, l:r, :] = crop_image
+        elif config.use_extender:
+            inpaint_result = self._do_outpainting(image, config)
         else:
             inpaint_result = self._scaled_pad_forward(image, mask, config)

@@ -297,10 +294,10 @@ class DiffusionInpaintModel(InpaintModel):
         # The cropper and the image share the same coordinate system; croper_x/y may be negative
         # Crop the outpainting region out of the image
         image_h, image_w = image.shape[:2]
-        cropper_l = config.croper_x
-        cropper_t = config.croper_y
-        cropper_r = config.croper_x + config.croper_width
-        cropper_b = config.croper_y + config.croper_height
+        cropper_l = config.extender_x
+        cropper_t = config.extender_y
+        cropper_r = config.extender_x + config.extender_width
+        cropper_b = config.extender_y + config.extender_height
         image_l = 0
         image_t = 0
         image_r = image_w
@@ -356,8 +353,8 @@ class DiffusionInpaintModel(InpaintModel):
         )[:, :, ::-1]

         # Paste cropped_result_image onto outpainting_image; no blending is needed in this step
-        paste_t = 0 if config.croper_y < 0 else config.croper_y
-        paste_l = 0 if config.croper_x < 0 else config.croper_x
+        paste_t = 0 if config.extender_y < 0 else config.extender_y
+        paste_l = 0 if config.extender_x < 0 else config.extender_x

         outpainting_image[
             paste_t : paste_t + expanded_cropped_result_image.shape[0],
@@ -397,8 +394,6 @@ class DiffusionInpaintModel(InpaintModel):
     def set_scheduler(self, config: Config):
         scheduler_config = self.model.scheduler.config
         sd_sampler = config.sd_sampler
-        if config.sd_lcm_lora:
-            sd_sampler = SDSampler.lcm
         scheduler = get_scheduler(sd_sampler, scheduler_config)
         self.model.scheduler = scheduler

@@ -31,6 +31,20 @@ class ControlNet(DiffusionInpaintModel):
     pad_mod = 8
     min_size = 512

+    @property
+    def lcm_lora_id(self):
+        if self.model_info.model_type in [
+            ModelType.DIFFUSERS_SD,
+            ModelType.DIFFUSERS_SD_INPAINT,
+        ]:
+            return "latent-consistency/lcm-lora-sdv1-5"
+        if self.model_info.model_type in [
+            ModelType.DIFFUSERS_SDXL,
+            ModelType.DIFFUSERS_SDXL_INPAINT,
+        ]:
+            return "latent-consistency/lcm-lora-sdxl"
+        raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")
+
     def init_model(self, device: torch.device, **kwargs):
         fp16 = not kwargs.get("no_half", False)
         model_info: ModelInfo = kwargs["model_info"]
@@ -72,7 +86,7 @@ class ControlNet(DiffusionInpaintModel):
        )

        controlnet = ControlNetModel.from_pretrained(
-            sd_controlnet_method, torch_dtype=torch_dtype
+            sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True
        )
        if model_info.is_single_file_diffusers:
            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@@ -81,7 +95,7 @@ class ControlNet(DiffusionInpaintModel):
                model_kwargs["num_in_channels"] = 9

            self.model = PipeClass.from_single_file(
-                model_info.path, controlnet=controlnet
+                model_info.path, controlnet=controlnet, **model_kwargs
            ).to(torch_dtype)
        else:
            self.model = PipeClass.from_pretrained(
@@ -39,7 +39,7 @@ class SDXL(DiffusionInpaintModel):
            )
        else:
            vae = AutoencoderKL.from_pretrained(
-                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
            )
            self.model = StableDiffusionXLInpaintPipeline.from_pretrained(
                self.model_id_or_path,
@@ -16,6 +16,7 @@ from diffusers import (
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
+    LCMScheduler
)

from lama_cleaner.schema import SDSampler
@@ -939,5 +940,7 @@ def get_scheduler(sd_sampler, scheduler_config):
        return DPMSolverMultistepScheduler.from_config(scheduler_config)
    elif sd_sampler == SDSampler.uni_pc:
        return UniPCMultistepScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.lcm:
+        return LCMScheduler.from_config(scheduler_config)
    else:
        raise ValueError(sd_sampler)
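Note: the new SDSampler.lcm branch simply maps the LCM sampler onto diffusers' LCMScheduler. A minimal sketch of the same swap outside this codebase, assuming a diffusers pipeline object named `pipe` (not part of this diff):

    # Sketch only: replace an existing pipeline's scheduler with LCMScheduler,
    # mirroring the get_scheduler() branch added above.
    from diffusers import LCMScheduler
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
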
@ -1,9 +1,9 @@
|
|||||||
import gc
|
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.const import DEFAULT_SD_CONTROLNET_METHOD
|
||||||
from lama_cleaner.download import scan_models
|
from lama_cleaner.download import scan_models
|
||||||
from lama_cleaner.helper import switch_mps_device
|
from lama_cleaner.helper import switch_mps_device
|
||||||
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
||||||
@@ -18,6 +18,11 @@ class ModelManager:
         self.kwargs = kwargs
         self.available_models: Dict[str, ModelInfo] = {}
         self.scan_models()
+
+        self.sd_controlnet = kwargs.get("sd_controlnet", False)
+        self.sd_controlnet_method = kwargs.get(
+            "sd_controlnet_method", DEFAULT_SD_CONTROLNET_METHOD
+        )
         self.model = self.init_model(name, device, **kwargs)

     def init_model(self, name: str, device, **kwargs):
|
|||||||
raise NotImplementedError(f"Unsupported model: {name}")
|
raise NotImplementedError(f"Unsupported model: {name}")
|
||||||
|
|
||||||
model_info = self.available_models[name]
|
model_info = self.available_models[name]
|
||||||
kwargs = {**kwargs, "model_info": model_info}
|
kwargs = {
|
||||||
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
|
**kwargs,
|
||||||
|
"model_info": model_info,
|
||||||
|
"sd_controlnet": self.sd_controlnet,
|
||||||
|
"sd_controlnet_method": self.sd_controlnet_method,
|
||||||
|
}
|
||||||
|
|
||||||
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
|
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
|
||||||
return models[name](device, **kwargs)
|
return models[name](device, **kwargs)
|
||||||
|
|
||||||
if sd_controlnet_enabled:
|
if self.sd_controlnet:
|
||||||
return ControlNet(device, **kwargs)
|
return ControlNet(device, **kwargs)
|
||||||
else:
|
else:
|
||||||
if model_info.model_type in [
|
if model_info.model_type in [
|
||||||
@ -51,7 +61,7 @@ class ModelManager:
|
|||||||
raise NotImplementedError(f"Unsupported model: {name}")
|
raise NotImplementedError(f"Unsupported model: {name}")
|
||||||
|
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: Config):
|
||||||
self.switch_controlnet_method(control_method=config.controlnet_method)
|
self.switch_controlnet_method(config)
|
||||||
self.enable_disable_freeu(config)
|
self.enable_disable_freeu(config)
|
||||||
self.enable_disable_lcm_lora(config)
|
self.enable_disable_lcm_lora(config)
|
||||||
return self.model(image, mask, config)
|
return self.model(image, mask, config)
|
||||||
@@ -66,40 +76,56 @@ class ModelManager:
             return

         old_name = self.name
+        old_sd_controlnet_method = self.sd_controlnet_method
         self.name = new_name

+        if (
+            self.available_models[new_name].support_controlnet
+            and self.sd_controlnet_method
+            not in self.available_models[new_name].controlnets
+        ):
+            self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
         try:
-            if torch.cuda.memory_allocated() > 0:
-                # Clear current loaded model from memory
-                torch.cuda.empty_cache()
             del self.model
-            gc.collect()
+            torch_gc()

             self.model = self.init_model(
                 new_name, switch_mps_device(new_name, self.device), **self.kwargs
             )
         except Exception as e:
             self.name = old_name
+            self.sd_controlnet_method = old_sd_controlnet_method
+            logger.info(f"Switch model from {old_name} to {new_name} failed, rollback")
+            self.model = self.init_model(
+                old_name, switch_mps_device(old_name, self.device), **self.kwargs
+            )
             raise e

-    def switch_controlnet_method(self, control_method: str):
-        if not self.kwargs.get("sd_controlnet"):
-            return
-        if self.kwargs["sd_controlnet_method"] == control_method:
-            return
-
+    def switch_controlnet_method(self, config):
         if not self.available_models[self.name].support_controlnet:
             return

-        del self.model
-        torch_gc()
+        if self.sd_controlnet != config.controlnet_enabled or (
+            self.sd_controlnet and self.sd_controlnet_method != config.controlnet_method
+        ):
+            # ControlNet may have been switched on or off,
+            # or ControlNet is already enabled and its method is being changed
+            old_sd_controlnet = self.sd_controlnet
+            old_sd_controlnet_method = self.sd_controlnet_method
+            self.sd_controlnet = config.controlnet_enabled
+            self.sd_controlnet_method = config.controlnet_method

-        old_method = self.kwargs["sd_controlnet_method"]
-        self.kwargs["sd_controlnet_method"] = control_method
             self.model = self.init_model(
                 self.name, switch_mps_device(self.name, self.device), **self.kwargs
             )
-        logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
+            if not config.controlnet_enabled:
+                logger.info(f"Disable controlnet")
+            elif old_sd_controlnet_method != config.controlnet_method:
+                logger.info(
+                    f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
+                )
+            else:
+                logger.info(f"Enable controlnet: {config.controlnet_method}")

     def enable_disable_freeu(self, config: Config):
         if str(self.model.device) == "mps":
@@ -120,7 +146,7 @@ class ModelManager:
     def enable_disable_lcm_lora(self, config: Config):
         if self.available_models[self.name].support_lcm_lora:
             if config.sd_lcm_lora:
-                if not self.model.model.pipe.get_list_adapters():
+                if not self.model.model.get_list_adapters():
                     self.model.model.load_lora_weights(self.model.lcm_lora_id)
             else:
                 self.model.model.disable_lora()
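For context, enable_disable_lcm_lora drives standard diffusers LoRA calls. A hedged sketch of the same flow, assuming `pipe` is a diffusers pipeline (the manager reaches it via self.model.model) and `request_wants_lcm_lora` stands in for config.sd_lcm_lora:

    # Sketch: mirrors the manager logic above.
    if request_wants_lcm_lora:
        if not pipe.get_list_adapters():    # load the LCM-LoRA weights only once
            pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    else:
        pipe.disable_lora()
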
@@ -234,16 +234,6 @@ def parse_args():
                "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
            )

-    if args.sd_local_model_path and args.model == "sd1.5":
-        if not os.path.exists(args.sd_local_model_path):
-            parser.error(
-                f"invalid --sd-local-model-path: {args.sd_local_model_path} not exists"
-            )
-        if not os.path.isfile(args.sd_local_model_path):
-            parser.error(
-                f"invalid --sd-local-model-path: {args.sd_local_model_path} is a directory"
-            )
-
    os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR
    if args.model_dir and args.model_dir is not None:
        if os.path.isfile(args.model_dir):
@@ -264,7 +254,7 @@ def parse_args():
        scanned_models = scan_models()
        if args.model not in [it.name for it in scanned_models]:
            parser.error(
-                f"invalid --model: {args.model} not exists. Available models: {AVAILABLE_MODELS} or {scanned_models}"
+                f"invalid --model: {args.model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}"
            )

    if args.input and args.input is not None:
@ -65,6 +65,28 @@ class ModelInfo(BaseModel):
|
|||||||
return SD_CONTROLNET_CHOICES
|
return SD_CONTROLNET_CHOICES
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_strength(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_outpainting(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
] or self.name in [
|
||||||
|
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
|
||||||
|
]
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def support_lcm_lora(self) -> bool:
|
def support_lcm_lora(self) -> bool:
|
||||||
@@ -129,10 +151,10 @@ class SDSampler(str, Enum):


 class FREEUConfig(BaseModel):
-    s1: float = 1.0
-    s2: float = 1.0
-    b1: float = 1.0
-    b2: float = 1.0
+    s1: float = 0.9
+    s2: float = 0.2
+    b1: float = 1.2
+    b2: float = 1.4


 class Config(BaseModel):
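FREEUConfig now carries non-trivial defaults (s1=0.9, s2=0.2, b1=1.2, b2=1.4) instead of 1.0 everywhere. A sketch of how FreeU factors like these are applied on the diffusers side, assuming a pipeline object `pipe` that is not shown in this diff:

    # Sketch: FreeU is toggled per request; these are the new default factors.
    pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
    # pipe.disable_freeu() reverts to the stock UNet behaviour.
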
@@ -140,18 +162,18 @@ class Config(BaseModel):
         arbitrary_types_allowed = True

     # Configs for ldm model
-    ldm_steps: int
+    ldm_steps: int = 20
     ldm_sampler: str = LDMSampler.plms

     # Configs for zits model
     zits_wireframe: bool = True

     # Configs for High Resolution Strategy(different way to preprocess image)
-    hd_strategy: str  # See HDStrategy Enum
-    hd_strategy_crop_margin: int
+    hd_strategy: str = HDStrategy.CROP  # See HDStrategy Enum
+    hd_strategy_crop_margin: int = 128
     # If the longer side of the image is larger than this value, use crop strategy
-    hd_strategy_crop_trigger_size: int
-    hd_strategy_resize_limit: int
+    hd_strategy_crop_trigger_size: int = 800
+    hd_strategy_resize_limit: int = 1280

     # Configs for Stable Diffusion 1.5
     prompt: str = ""
@@ -159,11 +181,15 @@ class Config(BaseModel):
     # Crop image to this size before doing sd inpainting
     # The value is always on the original image scale
     use_croper: bool = False
-    croper_is_outpainting: bool = False
     croper_x: int = None
     croper_y: int = None
     croper_height: int = None
     croper_width: int = None
+    use_extender: bool = False
+    extender_x: int = None
+    extender_y: int = None
+    extender_height: int = None
+    extender_width: int = None

     # Resize the image before doing sd inpainting, the area outside the mask will not lose quality.
     # Used by sd models and paint_by_example model
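The croper_is_outpainting flag is replaced by a dedicated extender block. A hedged usage sketch based on the outpainting tests later in this diff (the concrete numbers are illustrative only):

    # Sketch: extend the canvas instead of cropping; negative x/y grow it to the left/top.
    cfg = Config(
        prompt="a dog sitting on a bench in the park",
        use_extender=True,
        extender_x=-100,
        extender_y=-100,
        extender_width=712,
        extender_height=712,
    )
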
@@ -207,18 +233,12 @@ class Config(BaseModel):
     cv2_radius: int = 4

     # Paint by Example
-    paint_by_example_steps: int = 50
-    paint_by_example_guidance_scale: float = 7.5
-    paint_by_example_mask_blur: int = 0
-    paint_by_example_seed: int = 42
-    paint_by_example_match_histograms: bool = False
     paint_by_example_example_image: Optional[Image] = None

     # InstructPix2Pix
-    p2p_steps: int = 50
     p2p_image_guidance_scale: float = 1.5
-    p2p_guidance_scale: float = 7.5

     # ControlNet
+    controlnet_enabled: bool = False
     controlnet_conditioning_scale: float = 0.4
     controlnet_method: str = "control_v11p_sd15_canny"
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import json
 import os
 import hashlib
 import traceback
@@ -103,7 +104,7 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())

 app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
 app.config["JSON_AS_ASCII"] = False
-CORS(app, expose_headers=["Content-Disposition"])
+CORS(app, expose_headers=["Content-Disposition", "X-seed"])

 sio_logger = logging.getLogger("sio-logger")
 sio_logger.setLevel(logging.ERROR)
@@ -115,8 +116,6 @@ output_dir: str = None
 device = None
 input_image_path: str = None
 is_disable_model_switch: bool = False
-is_controlnet: bool = False
-controlnet_method: str = "control_v11p_sd15_canny"
 enable_file_manager: bool = False
 enable_auto_saving: bool = False
 is_desktop: bool = False
@@ -266,26 +265,21 @@ def process():
         sd_guidance_scale=form["sdGuidanceScale"],
         sd_sampler=form["sdSampler"],
         sd_seed=form["sdSeed"],
+        sd_freeu=form["enableFreeu"],
+        sd_freeu_config=json.loads(form["freeuConfig"]),
+        sd_lcm_lora=form["enableLCMLora"],
         sd_match_histograms=form["sdMatchHistograms"],
         cv2_flag=form["cv2Flag"],
         cv2_radius=form["cv2Radius"],
-        paint_by_example_steps=form["paintByExampleSteps"],
-        paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
-        paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
-        paint_by_example_seed=form["paintByExampleSeed"],
-        paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
         paint_by_example_example_image=paint_by_example_example_image,
-        p2p_steps=form["p2pSteps"],
         p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
-        p2p_guidance_scale=form["p2pGuidanceScale"],
+        controlnet_enabled=form["controlnet_enabled"],
         controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
         controlnet_method=form["controlnet_method"],
     )

     if config.sd_seed == -1:
-        config.sd_seed = random.randint(1, 999999999)
-    if config.paint_by_example_seed == -1:
-        config.paint_by_example_seed = random.randint(1, 999999999)
+        config.sd_seed = random.randint(1, 99999999)

     logger.info(f"Origin image shape: {original_shape}")
     image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
@@ -424,6 +418,8 @@ def get_server_config():
         "plugins": list(plugins.keys()),
         "enableFileManager": enable_file_manager,
         "enableAutoSaving": enable_auto_saving,
+        "enableControlnet": model.sd_controlnet,
+        "controlnetMethod": model.sd_controlnet_method,
     }, 200


@@ -540,18 +536,12 @@ def main(args):
     global is_desktop
     global thumb
     global output_dir
-    global is_controlnet
-    global controlnet_method
     global image_quality
+    global enable_auto_saving

     build_plugins(args)

     image_quality = args.quality

-    if args.sd_controlnet and args.model in SD15_MODELS:
-        is_controlnet = True
-        controlnet_method = args.sd_controlnet_method
-
     output_dir = args.output_dir
     if output_dir:
         output_dir = os.path.abspath(output_dir)
@@ -609,9 +599,6 @@ def main(args):
         hf_access_token=args.hf_access_token,
         disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
         sd_cpu_textencoder=args.sd_cpu_textencoder,
-        sd_run_local=args.sd_run_local,
-        sd_local_model_path=args.sd_local_model_path,
-        local_files_only=args.local_files_only,
         cpu_offload=args.cpu_offload,
         enable_xformers=args.sd_enable_xformers or args.enable_xformers,
         callback=diffuser_callback,
@@ -39,7 +39,6 @@ def test_runway_sd_1_5(
         sd_controlnet=True,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=False,
         disable_nsfw=disable_nsfw,
         sd_cpu_textencoder=cpu_textencoder,
         sd_controlnet_method=sd_controlnet_method,
@@ -88,11 +87,9 @@ def test_local_file_path(sd_device, sampler):
         sd_controlnet=True,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=False,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
-        sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
         sd_controlnet_method="control_v11p_sd15_canny",
     )
     cfg = get_config(
@@ -128,7 +125,6 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
         sd_controlnet=True,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=False,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
@@ -170,7 +166,6 @@ def test_controlnet_switch(sd_device, sampler):
         sd_controlnet=True,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=False,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
@@ -8,23 +8,30 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
 from lama_cleaner.schema import HDStrategy

 current_dir = Path(__file__).parent.absolute().resolve()
-save_dir = current_dir / 'result'
+save_dir = current_dir / "result"
 save_dir.mkdir(exist_ok=True, parents=True)
-device = 'cuda' if torch.cuda.is_available() else 'mps'
+device = "cuda" if torch.cuda.is_available() else "mps"
+model_name = "timbrooks/instruct-pix2pix"


 @pytest.mark.parametrize("disable_nsfw", [True, False])
 @pytest.mark.parametrize("cpu_offload", [False, True])
 def test_instruct_pix2pix(disable_nsfw, cpu_offload):
-    sd_steps = 50 if device == 'cuda' else 20
-    model = ModelManager(name="instruct_pix2pix",
+    sd_steps = 50 if device == "cuda" else 20
+    model = ModelManager(
+        name=model_name,
         device=torch.device(device),
         hf_access_token="",
-        sd_run_local=False,
         disable_nsfw=disable_nsfw,
         sd_cpu_textencoder=False,
-        cpu_offload=cpu_offload)
-    cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps, sd_scale=1.1)
+        cpu_offload=cpu_offload,
+    )
+    cfg = get_config(
+        strategy=HDStrategy.ORIGINAL,
+        prompt="What if it were snowing?",
+        p2p_steps=sd_steps,
+        sd_scale=1.1,
+    )

     name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"

@@ -34,22 +41,27 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
         f"instruct_pix2pix_{name}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-        fx=1.3
+        fx=1.3,
     )


 @pytest.mark.parametrize("disable_nsfw", [False])
 @pytest.mark.parametrize("cpu_offload", [False])
 def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
-    sd_steps = 50 if device == 'cuda' else 20
-    model = ModelManager(name="instruct_pix2pix",
+    sd_steps = 50 if device == "cuda" else 20
+    model = ModelManager(
+        name=model_name,
         device=torch.device(device),
         hf_access_token="",
-        sd_run_local=False,
         disable_nsfw=disable_nsfw,
         sd_cpu_textencoder=False,
-        cpu_offload=cpu_offload)
-    cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps)
+        cpu_offload=cpu_offload,
+    )
+    cfg = get_config(
+        strategy=HDStrategy.ORIGINAL,
+        prompt="What if it were snowing?",
+        p2p_steps=sd_steps,
+    )

     name = f"snow"

@@ -20,8 +20,6 @@ def test_load_model():
         hf_access_token="",
         disable_nsfw=False,
         sd_cpu_textencoder=True,
-        sd_run_local=True,
-        local_files_only=True,
         cpu_offload=True,
         enable_xformers=False,
     )
@@ -0,0 +1,80 @@
+import logging
+import os
+
+from lama_cleaner.schema import Config
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+import torch
+
+from lama_cleaner.model_manager import ModelManager
+
+
+def test_model_switch():
+    model = ModelManager(
+        name="runwayml/stable-diffusion-inpainting",
+        sd_controlnet=True,
+        sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
+        device=torch.device("mps"),
+        hf_access_token="",
+        disable_nsfw=True,
+        sd_cpu_textencoder=True,
+        cpu_offload=False,
+        enable_xformers=False,
+        callback=None,
+    )
+
+    model.switch("lama")
+
+
+def test_controlnet_switch_onoff(caplog):
+    name = "runwayml/stable-diffusion-inpainting"
+    model = ModelManager(
+        name=name,
+        sd_controlnet=True,
+        sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
+        device=torch.device("mps"),
+        hf_access_token="",
+        disable_nsfw=True,
+        sd_cpu_textencoder=True,
+        cpu_offload=False,
+        enable_xformers=False,
+        callback=None,
+    )
+
+    model.switch_controlnet_method(
+        Config(
+            name=name,
+            controlnet_enabled=False,
+        )
+    )
+
+    assert "Disable controlnet" in caplog.text
+
+
+def test_controlnet_switch_method(caplog):
+    name = "runwayml/stable-diffusion-inpainting"
+    old_method = "lllyasviel/control_v11p_sd15_canny"
+    new_method = "lllyasviel/control_v11p_sd15_openpose"
+    model = ModelManager(
+        name=name,
+        sd_controlnet=True,
+        sd_controlnet_method=old_method,
+        device=torch.device("mps"),
+        hf_access_token="",
+        disable_nsfw=True,
+        sd_cpu_textencoder=True,
+        cpu_offload=False,
+        enable_xformers=False,
+        callback=None,
+    )
+
+    model.switch_controlnet_method(
+        Config(
+            name=name,
+            controlnet_enabled=True,
+            controlnet_method=new_method,
+        )
+    )
+
+    assert f"Switch Controlnet method from {old_method} to {new_method}" in caplog.text
@@ -17,7 +17,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 device = torch.device(device)


-@pytest.mark.parametrize("name", ["sd1.5"])
+@pytest.mark.parametrize("name", ["runwayml/stable-diffusion-inpainting"])
 @pytest.mark.parametrize("sd_device", ["mps"])
 @pytest.mark.parametrize(
     "rect",
@@ -42,7 +42,6 @@ def test_outpainting(name, sd_device, rect):
         name=name,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         callback=callback,
@@ -51,12 +50,11 @@ def test_outpainting(name, sd_device, rect):
         HDStrategy.ORIGINAL,
         prompt="a dog sitting on a bench in the park",
         sd_steps=50,
-        use_croper=True,
-        croper_is_outpainting=True,
-        croper_x=rect[0],
-        croper_y=rect[1],
-        croper_width=rect[2],
-        croper_height=rect[3],
+        use_extender=True,
+        extender_x=rect[0],
+        extender_y=rect[1],
+        extender_width=rect[2],
+        extender_height=rect[3],
         sd_guidance_scale=8.0,
         sd_sampler=SDSampler.dpm_plus_plus,
     )
@ -64,13 +62,13 @@ def test_outpainting(name, sd_device, rect):
|
|||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
cfg,
|
cfg,
|
||||||
f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
|
f"{name.replace('/', '--')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["kandinsky2.2"])
|
@pytest.mark.parametrize("name", ["kandinsky-community/kandinsky-2-2-decoder-inpaint"])
|
||||||
@pytest.mark.parametrize("sd_device", ["mps"])
|
@pytest.mark.parametrize("sd_device", ["mps"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rect",
|
"rect",
|
||||||
@@ -86,10 +84,9 @@ def test_kandinsky_outpainting(name, sd_device, rect):
         return

     model = ModelManager(
-        name="sd1.5",
+        name=name,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         callback=callback,
@@ -99,12 +96,11 @@ def test_kandinsky_outpainting(name, sd_device, rect):
         prompt="a cat",
         negative_prompt="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
         sd_steps=50,
-        use_croper=True,
-        croper_is_outpainting=True,
-        croper_x=rect[0],
-        croper_y=rect[1],
-        croper_width=rect[2],
-        croper_height=rect[3],
+        use_extender=True,
+        extender_x=rect[0],
+        extender_y=rect[1],
+        extender_width=rect[2],
+        extender_height=rect[3],
         sd_guidance_scale=7,
         sd_sampler=SDSampler.dpm_plus_plus,
     )
@@ -112,7 +108,7 @@ def test_kandinsky_outpainting(name, sd_device, rect):
     assert_equal(
         model,
         cfg,
-        f"{name.replace('.', '_')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
+        f"{name.replace('/', '--')}_outpainting_dpm++_{'_'.join(map(str, rect))}.png",
         img_p=current_dir / "cat.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
         fx=1,
@@ -10,15 +10,19 @@ from lama_cleaner.schema import HDStrategy
 from lama_cleaner.tests.test_model import get_config, get_data

 current_dir = Path(__file__).parent.absolute().resolve()
-save_dir = current_dir / 'result'
+save_dir = current_dir / "result"
 save_dir.mkdir(exist_ok=True, parents=True)
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
+device = "cuda" if torch.cuda.is_available() else "mps"
 device = torch.device(device)
+model_name = "Fantasy-Studio/Paint-by-Example"


 def assert_equal(
-    model, config, gt_name,
-    fx: float = 1, fy: float = 1,
+    model,
+    config,
+    gt_name,
+    fx: float = 1,
+    fy: float = 1,
     img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
     mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     example_p=current_dir / "bunny.jpeg",
@@ -27,7 +31,9 @@ def assert_equal(

     example_image = cv2.imread(str(example_p))
     example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB)
-    example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
+    example_image = cv2.resize(
+        example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA
+    )

     print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
     config.paint_by_example_example_image = Image.fromarray(example_image)
@@ -35,14 +41,13 @@ def assert_equal(
     cv2.imwrite(str(save_dir / gt_name), res)


-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-def test_paint_by_example(strategy):
-    model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
-    cfg = get_config(strategy, paint_by_example_steps=30)
+def test_paint_by_example():
+    model = ModelManager(name=model_name, device=device, disable_nsfw=True)
+    cfg = get_config(HDStrategy.ORIGINAL, sd_steps=30)
     assert_equal(
         model,
         cfg,
-        f"paint_by_example_{strategy.capitalize()}.png",
+        f"paint_by_example.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
         fy=0.9,
@@ -50,57 +55,31 @@ def test_paint_by_example(strategy):
     )


-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-def test_paint_by_example_disable_nsfw(strategy):
-    model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False)
-    cfg = get_config(strategy, paint_by_example_steps=30)
+def test_paint_by_example_cpu_offload():
+    model = ModelManager(
+        name=model_name, device=device, cpu_offload=True, disable_nsfw=False
+    )
+    cfg = get_config(HDStrategy.ORIGINAL, sd_steps=30)
     assert_equal(
         model,
         cfg,
-        f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png",
+        f"paint_by_example_cpu_offload.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     )


-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-def test_paint_by_example_sd_scale(strategy):
-    model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
-    cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
+def test_paint_by_example_cpu_offload_cpu_device():
+    model = ModelManager(
+        name=model_name, device=torch.device("cpu"), cpu_offload=True, disable_nsfw=True
+    )
+    cfg = get_config(HDStrategy.ORIGINAL, sd_steps=1)
     assert_equal(
         model,
         cfg,
-        f"paint_by_example_{strategy.capitalize()}_sdscale.png",
+        f"paint_by_example_cpu_offload_cpu_device.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
         fy=0.9,
-        fx=1.3
-    )
-
-
-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-def test_paint_by_example_cpu_offload(strategy):
-    model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False)
-    cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
-    assert_equal(
-        model,
-        cfg,
-        f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
-        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
-        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-    )
-
-
-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-def test_paint_by_example_cpu_offload_cpu_device(strategy):
-    model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True)
-    cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85)
-    assert_equal(
-        model,
-        cfg,
-        f"paint_by_example_{strategy.capitalize()}_cpu_offload_cpu_device.png",
-        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
-        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-        fy=0.9,
-        fx=1.3
+        fx=1.3,
     )
@@ -7,7 +7,7 @@ import pytest
 import torch

 from lama_cleaner.model_manager import ModelManager
-from lama_cleaner.schema import HDStrategy, SDSampler
+from lama_cleaner.schema import HDStrategy, SDSampler, FREEUConfig
 from lama_cleaner.tests.test_model import get_config, assert_equal

 current_dir = Path(__file__).parent.absolute().resolve()
@@ -16,178 +16,127 @@ save_dir.mkdir(exist_ok=True, parents=True)


 @pytest.mark.parametrize("sd_device", ["cuda", "mps"])
-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-@pytest.mark.parametrize("sampler", [SDSampler.ddim])
-@pytest.mark.parametrize("cpu_textencoder", [True, False])
-@pytest.mark.parametrize("disable_nsfw", [True, False])
-def test_runway_sd_1_5_ddim(
-    sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
-):
-    def callback(i, t, latents):
-        pass
-
-    if sd_device == "cuda" and not torch.cuda.is_available():
-        return
-
-    sd_steps = 50 if sd_device == "cuda" else 1
-    model = ModelManager(
-        name="sd1.5",
-        device=torch.device(sd_device),
-        hf_access_token="",
-        sd_run_local=True,
-        disable_nsfw=disable_nsfw,
-        sd_cpu_textencoder=cpu_textencoder,
-        callback=callback,
-    )
-    cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
-    cfg.sd_sampler = sampler
-
-    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
-
-    assert_equal(
-        model,
-        cfg,
-        f"runway_sd_{strategy.capitalize()}_{name}.png",
-        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
-        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-        fx=1.3,
-    )
-
-
-@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
 @pytest.mark.parametrize(
-    "sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a]
+    "sampler",
+    [
+        SDSampler.ddim,
+        SDSampler.pndm,
+        SDSampler.k_lms,
+        SDSampler.k_euler,
+        SDSampler.k_euler_a,
+        SDSampler.lcm,
+    ],
 )
-@pytest.mark.parametrize("cpu_textencoder", [False])
-@pytest.mark.parametrize("disable_nsfw", [True])
-def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
-    def callback(i, t, latents):
-        print(f"sd_step_{i}")
-
+def test_runway_sd_1_5_all_samplers(
+    sd_device,
+    sampler,
+):
     if sd_device == "cuda" and not torch.cuda.is_available():
         return

-    sd_steps = 50 if sd_device == "cuda" else 1
+    sd_steps = 30
     model = ModelManager(
-        name="sd1.5",
+        name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
-        disable_nsfw=disable_nsfw,
-        sd_cpu_textencoder=cpu_textencoder,
-        callback=callback,
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+    )
+    cfg = get_config(
+        HDStrategy.ORIGINAL, prompt="a fox sitting on a bench", sd_steps=sd_steps
     )
-    cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
     cfg.sd_sampler = sampler

-    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
+    name = f"device_{sd_device}_{sampler}"

     assert_equal(
         model,
         cfg,
-        f"runway_sd_{strategy.capitalize()}_{name}.png",
+        f"runway_sd_{name}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-        fx=1.3,
     )


 @pytest.mark.parametrize("sd_device", ["cuda", "mps"])
 @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-@pytest.mark.parametrize("sampler", [SDSampler.ddim])
-@pytest.mark.parametrize("sd_prevent_unmasked_area", [False, True])
-def test_runway_sd_1_5_negative_prompt(
-    sd_device, strategy, sampler, sd_prevent_unmasked_area
-):
-    def callback(i, t, latents):
-        pass
-
+@pytest.mark.parametrize("sampler", [SDSampler.lcm])
+def test_runway_sd_lcm_lora(sd_device, strategy, sampler):
     if sd_device == "cuda" and not torch.cuda.is_available():
         return

-    sd_steps = 50 if sd_device == "cuda" else 20
+    sd_steps = 5
     model = ModelManager(
-        name="sd1.5",
+        name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
-        disable_nsfw=False,
+        disable_nsfw=True,
         sd_cpu_textencoder=False,
-        callback=callback,
     )
     cfg = get_config(
         strategy,
+        prompt="face of a fox, sitting on a bench",
         sd_steps=sd_steps,
-        prompt="Face of a fox, high resolution, sitting on a park bench",
-        negative_prompt="orange, yellow, small",
-        sd_sampler=sampler,
-        sd_match_histograms=True,
-        sd_prevent_unmasked_area=sd_prevent_unmasked_area,
-    )
-
-    name = f"{sampler}_negative_prompt"
-
-    assert_equal(
-        model,
-        cfg,
-        f"runway_sd_{strategy.capitalize()}_{name}_prevent_unmasked_area_{sd_prevent_unmasked_area}.png",
-        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
-        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-        fx=1,
-    )
-
-
-@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
-@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
-@pytest.mark.parametrize("cpu_textencoder", [False])
-@pytest.mark.parametrize("disable_nsfw", [False])
-def test_runway_sd_1_5_sd_scale(
-    sd_device, strategy, sampler, cpu_textencoder, disable_nsfw
-):
-    if sd_device == "cuda" and not torch.cuda.is_available():
-        return
-
-    sd_steps = 50 if sd_device == "cuda" else 20
-    model = ModelManager(
-        name="sd1.5",
-        device=torch.device(sd_device),
-        hf_access_token="",
-        sd_run_local=True,
-        disable_nsfw=disable_nsfw,
-        sd_cpu_textencoder=cpu_textencoder,
-    )
-    cfg = get_config(
-        strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
+        sd_guidance_scale=2,
+        sd_lcm_lora=True,
     )
     cfg.sd_sampler = sampler

-    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
-
     assert_equal(
         model,
         cfg,
-        f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
+        f"runway_sd_1_5_lcm_lora.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-        fx=1.3,
     )


 @pytest.mark.parametrize("sd_device", ["cuda", "mps"])
 @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim])
+def test_runway_sd_freeu(sd_device, strategy, sampler):
+    if sd_device == "cuda" and not torch.cuda.is_available():
+        return
+
+    sd_steps = 30
+    model = ModelManager(
+        name="runwayml/stable-diffusion-inpainting",
+        device=torch.device(sd_device),
+        hf_access_token="",
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+    )
+    cfg = get_config(
+        strategy,
+        prompt="face of a fox, sitting on a bench",
+        sd_steps=sd_steps,
+        sd_guidance_scale=7.5,
+        sd_freeu=True,
+        sd_freeu_config=FREEUConfig(),
+    )
+    cfg.sd_sampler = sampler
+
+    assert_equal(
+        model,
+        cfg,
+        f"runway_sd_1_5_freeu.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
+
+
+@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim])
 def test_runway_sd_sd_strength(sd_device, strategy, sampler):
     if sd_device == "cuda" and not torch.cuda.is_available():
         return

-    sd_steps = 50 if sd_device == "cuda" else 20
+    sd_steps = 30
     model = ModelManager(
-        name="sd1.5",
+        name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
|
||||||
disable_nsfw=True,
|
disable_nsfw=True,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
)
|
)
|
||||||
@@ -205,6 +154,33 @@ def test_runway_sd_sd_strength(sd_device, strategy, sampler):
     )


+@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim])
+def test_runway_norm_sd_model(sd_device, strategy, sampler):
+    if sd_device == "cuda" and not torch.cuda.is_available():
+        return
+
+    sd_steps = 30
+    model = ModelManager(
+        name="runwayml/stable-diffusion-v1-5",
+        device=torch.device(sd_device),
+        hf_access_token="",
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+    )
+    cfg = get_config(strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps)
+    cfg.sd_sampler = sampler
+
+    assert_equal(
+        model,
+        cfg,
+        f"runway_{sd_device}_norm_sd_model.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
+
+
 @pytest.mark.parametrize("sd_device", ["cuda"])
 @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
 @pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
@@ -212,19 +188,16 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
     if sd_device == "cuda" and not torch.cuda.is_available():
         return

-    sd_steps = 50 if sd_device == "cuda" else 20
+    sd_steps = 30
     model = ModelManager(
-        name="sd1.5",
+        name="runwayml/stable-diffusion-inpainting",
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
     )
-    cfg = get_config(
-        strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_scale=0.85
-    )
+    cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
     cfg.sd_sampler = sampler

     name = f"device_{sd_device}_{sampler}"
@@ -239,28 +212,27 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):


 @pytest.mark.parametrize("sd_device", ["cuda", "mps"])
-@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim])
 @pytest.mark.parametrize(
-    "local_model_path",
+    "name",
     [
-        "/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
-        "/Users/cwq/data/models/sd-v1-5-inpainting.safetensors",
+        "sd-v1-5-inpainting.ckpt",
+        "sd-v1-5-inpainting.safetensors",
+        "v1-5-pruned-emaonly.safetensors",
     ],
 )
-def test_local_file_path(sd_device, sampler, local_model_path):
+def test_local_file_path(sd_device, sampler, name):
     if sd_device == "cuda" and not torch.cuda.is_available():
         return

-    sd_steps = 1 if sd_device == "cpu" else 30
+    sd_steps = 30
     model = ModelManager(
-        name="sd1.5",
+        name=name,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
-        cpu_offload=True,
-        sd_local_model_path=local_model_path,
+        cpu_offload=False,
     )
     cfg = get_config(
         HDStrategy.ORIGINAL,
@@ -269,7 +241,7 @@ def test_local_file_path(sd_device, sampler, local_model_path):
     )
     cfg.sd_sampler = sampler

-    name = f"device_{sd_device}_{sampler}_{Path(local_model_path).stem}"
+    name = f"device_{sd_device}_{sampler}_{name}"

     assert_equal(
         model,
@@ -7,7 +7,7 @@ import pytest
 import torch

 from lama_cleaner.model_manager import ModelManager
-from lama_cleaner.schema import HDStrategy, SDSampler
+from lama_cleaner.schema import HDStrategy, SDSampler, FREEUConfig
 from lama_cleaner.tests.test_model import get_config, assert_equal

 current_dir = Path(__file__).parent.absolute().resolve()
@@ -15,12 +15,10 @@ save_dir = current_dir / "result"
 save_dir.mkdir(exist_ok=True, parents=True)


-@pytest.mark.parametrize("sd_device", ["mps"])
+@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
 @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
 @pytest.mark.parametrize("sampler", [SDSampler.ddim])
-@pytest.mark.parametrize("cpu_textencoder", [False])
-@pytest.mark.parametrize("disable_nsfw", [True])
-def test_sdxl(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
+def test_sdxl(sd_device, strategy, sampler):
     def callback(i, t, latents):
         pass

@@ -29,24 +27,23 @@ def test_sdxl(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):

     sd_steps = 20
     model = ModelManager(
-        name="sdxl",
+        name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=False,
-        disable_nsfw=disable_nsfw,
-        sd_cpu_textencoder=cpu_textencoder,
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
         callback=callback,
     )
     cfg = get_config(
         strategy,
-        prompt="a fox sitting on a bench",
+        prompt="face of a fox, sitting on a bench",
         sd_steps=sd_steps,
-        sd_strength=0.99,
+        sd_strength=1.0,
         sd_guidance_scale=7.0,
     )
     cfg.sd_sampler = sampler

-    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
+    name = f"device_{sd_device}_{sampler}"

     assert_equal(
         model,
@@ -59,6 +56,67 @@ def test_sdxl(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
     )


+@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim])
+def test_sdxl_lcm_lora_and_freeu(sd_device, strategy, sampler):
+    def callback(i, t, latents):
+        pass
+
+    if sd_device == "cuda" and not torch.cuda.is_available():
+        return
+
+    sd_steps = 5
+    model = ModelManager(
+        name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+        device=torch.device(sd_device),
+        hf_access_token="",
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+        callback=callback,
+    )
+    cfg = get_config(
+        strategy,
+        prompt="face of a fox, sitting on a bench",
+        sd_steps=sd_steps,
+        sd_strength=1.0,
+        sd_guidance_scale=2.0,
+        sd_lcm_lora=True,
+    )
+    cfg.sd_sampler = sampler
+
+    name = f"device_{sd_device}_{sampler}"
+
+    assert_equal(
+        model,
+        cfg,
+        f"sdxl_{name}_lcm_lora.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+        fx=2,
+        fy=2,
+    )
+
+    cfg = get_config(
+        strategy,
+        prompt="face of a fox, sitting on a bench",
+        sd_steps=sd_steps,
+        sd_guidance_scale=7.5,
+        sd_freeu=True,
+        sd_freeu_config=FREEUConfig(),
+    )
+
+    assert_equal(
+        model,
+        cfg,
+        f"sdxl_{name}_freeu.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+        fx=2,
+        fy=2,
+    )
+
+
 @pytest.mark.parametrize("sd_device", ["mps"])
 @pytest.mark.parametrize(
     "rect",
@@ -67,33 +125,26 @@ def test_sdxl(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
     ],
 )
 def test_sdxl_outpainting(sd_device, rect):
-    def callback(i, t, latents):
-        pass
-
     if sd_device == "cuda" and not torch.cuda.is_available():
         return

     model = ModelManager(
-        name="sdxl",
+        name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
-        callback=callback,
     )

     cfg = get_config(
         HDStrategy.ORIGINAL,
         prompt="a dog sitting on a bench in the park",
-        negative_prompt="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
         sd_steps=20,
-        use_croper=True,
-        croper_is_outpainting=True,
-        croper_x=rect[0],
-        croper_y=rect[1],
-        croper_width=rect[2],
-        croper_height=rect[3],
+        use_extender=True,
+        extender_x=rect[0],
+        extender_y=rect[1],
+        extender_width=rect[2],
+        extender_height=rect[3],
         sd_strength=1.0,
         sd_guidance_scale=8.0,
         sd_sampler=SDSampler.ddim,