sdxl support cpu_text_encoder
This commit is contained in:
parent
05a15b2e1f
commit
38b6edacf0
@ -8,6 +8,7 @@ class CPUTextEncoderWrapper(PreTrainedModel):
|
||||
def __init__(self, text_encoder, torch_dtype):
|
||||
super().__init__(text_encoder.config)
|
||||
self.config = text_encoder.config
|
||||
# cpu not support float16
|
||||
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||
self.torch_dtype = torch_dtype
|
||||
@ -16,11 +17,15 @@ class CPUTextEncoderWrapper(PreTrainedModel):
|
||||
|
||||
def __call__(self, x, **kwargs):
|
||||
input_device = x.device
|
||||
return [
|
||||
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
|
||||
.to(input_device)
|
||||
.to(self.torch_dtype)
|
||||
original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
|
||||
for k, v in original_output.items():
|
||||
if isinstance(v, tuple):
|
||||
original_output[k] = [
|
||||
v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
|
||||
]
|
||||
else:
|
||||
original_output[k] = v.to(input_device).to(self.torch_dtype)
|
||||
return original_output
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
|
@ -9,6 +9,7 @@ from loguru import logger
|
||||
from iopaint.schema import InpaintRequest, ModelType
|
||||
|
||||
from .base import DiffusionInpaintModel
|
||||
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
||||
|
||||
|
||||
@ -37,11 +38,11 @@ class SDXL(DiffusionInpaintModel):
|
||||
)
|
||||
else:
|
||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
||||
if 'vae' not in model_kwargs:
|
||||
if "vae" not in model_kwargs:
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
|
||||
)
|
||||
model_kwargs['vae'] = vae
|
||||
model_kwargs["vae"] = vae
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
StableDiffusionXLInpaintPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
@ -58,7 +59,13 @@ class SDXL(DiffusionInpaintModel):
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs["sd_cpu_textencoder"]:
|
||||
logger.warning("Stable Diffusion XL not support run TextEncoder on CPU")
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder, torch_dtype
|
||||
)
|
||||
self.model.text_encoder_2 = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder_2, torch_dtype
|
||||
)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
|
@ -147,6 +147,33 @@ def test_runway_sd_sd_strength(device, strategy, sampler):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_runway_sd_cpu_textencoder(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_sampler=sampler,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_device_{device}_cpu_textencoder.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
|
@ -44,6 +44,38 @@ def test_sdxl(device, strategy, sampler):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_sdxl_cpu_text_encoder(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_strength=1.0,
|
||||
sd_guidance_scale=7.0,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sdxl_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=2,
|
||||
fy=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
|
@ -17,7 +17,7 @@ def check_device(device: str) -> int:
|
||||
pytest.skip("CUDA is not available, skip test on cuda")
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
pytest.skip("mps is not available, skip test on mps")
|
||||
steps = 1 if device == "cpu" else 20
|
||||
steps = 2 if device == "cpu" else 20
|
||||
return steps
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user