SDXL: add support for running the text encoder on CPU

This commit is contained in:
Qing 2024-01-10 13:34:11 +08:00
parent 05a15b2e1f
commit 38b6edacf0
5 changed files with 80 additions and 9 deletions

View File

@ -8,6 +8,7 @@ class CPUTextEncoderWrapper(PreTrainedModel):
def __init__(self, text_encoder, torch_dtype):
    """Wrap *text_encoder* so it always runs on the CPU in float32.

    Args:
        text_encoder: the text encoder module to move to CPU.
        torch_dtype: dtype the caller expects outputs in; remembered so
            results can be cast back when the encoder is invoked.
    """
    super().__init__(text_encoder.config)
    self.config = text_encoder.config
    # float16 inference is not supported on CPU, so force float32 there.
    cpu_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
    self.text_encoder = cpu_encoder.to(torch.float32, non_blocking=True)
    self.torch_dtype = torch_dtype
@ -16,11 +17,15 @@ class CPUTextEncoderWrapper(PreTrainedModel):
def __call__(self, x, **kwargs):
    """Run the wrapped encoder on CPU, then cast results back.

    Moves *x* onto the encoder's device, runs the encoder, and converts
    every entry of the returned mapping back to *x*'s original device
    and to the requested dtype.  Tuple-valued entries are converted
    element-wise (and become lists).
    """
    source_device = x.device
    output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
    # Keys are stable while iterating: we only reassign existing entries.
    for key, value in output.items():
        if isinstance(value, tuple):
            output[key] = [
                item.to(source_device).to(self.torch_dtype) for item in value
            ]
        else:
            output[key] = value.to(source_device).to(self.torch_dtype)
    return output
@property @property
def dtype(self): def dtype(self):

View File

@ -9,6 +9,7 @@ from loguru import logger
from iopaint.schema import InpaintRequest, ModelType from iopaint.schema import InpaintRequest, ModelType
from .base import DiffusionInpaintModel from .base import DiffusionInpaintModel
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
@ -37,11 +38,11 @@ class SDXL(DiffusionInpaintModel):
) )
else: else:
model_kwargs = {**kwargs.get("pipe_components", {})} model_kwargs = {**kwargs.get("pipe_components", {})}
if 'vae' not in model_kwargs: if "vae" not in model_kwargs:
vae = AutoencoderKL.from_pretrained( vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
) )
model_kwargs['vae'] = vae model_kwargs["vae"] = vae
self.model = handle_from_pretrained_exceptions( self.model = handle_from_pretrained_exceptions(
StableDiffusionXLInpaintPipeline.from_pretrained, StableDiffusionXLInpaintPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path, pretrained_model_name_or_path=self.model_id_or_path,
@ -58,7 +59,13 @@ class SDXL(DiffusionInpaintModel):
else: else:
self.model = self.model.to(device) self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]: if kwargs["sd_cpu_textencoder"]:
logger.warning("Stable Diffusion XL not support run TextEncoder on CPU") logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.model.text_encoder_2 = CPUTextEncoderWrapper(
self.model.text_encoder_2, torch_dtype
)
self.callback = kwargs.pop("callback", None) self.callback = kwargs.pop("callback", None)

View File

@ -147,6 +147,33 @@ def test_runway_sd_sd_strength(device, strategy, sampler):
) )
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_sd_cpu_textencoder(device, strategy, sampler):
    """Inpaint with the runway SD model while the text encoder runs on CPU."""
    steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=True,
    )
    config = get_config(
        strategy=strategy,
        prompt="a fox sitting on a bench",
        sd_steps=steps,
        sd_sampler=sampler,
    )
    assert_equal(
        model,
        config,
        f"runway_sd_device_{device}_cpu_textencoder.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim]) @pytest.mark.parametrize("sampler", [SDSampler.ddim])

View File

@ -44,6 +44,38 @@ def test_sdxl(device, strategy, sampler):
) )
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_sdxl_cpu_text_encoder(device, strategy, sampler):
    """Inpaint with SDXL while its text encoders run on CPU."""
    sd_steps = check_device(device)
    model = ModelManager(
        name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=True,
    )
    # Pass the sampler through get_config, consistent with the other
    # cpu-textencoder test, instead of mutating cfg after construction.
    cfg = get_config(
        strategy=strategy,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
        sd_strength=1.0,
        sd_guidance_scale=7.0,
        sd_sampler=sampler,
    )
    # Dedicated golden-image name so this test does not collide with the
    # plain SDXL test's result file (was f"sdxl_device_{device}.png").
    assert_equal(
        model,
        cfg,
        f"sdxl_device_{device}_cpu_textencoder.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=2,
        fy=2,
    )
@pytest.mark.parametrize("device", ["cuda", "mps"]) @pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim]) @pytest.mark.parametrize("sampler", [SDSampler.ddim])

View File

@ -17,7 +17,7 @@ def check_device(device: str) -> int:
pytest.skip("CUDA is not available, skip test on cuda") pytest.skip("CUDA is not available, skip test on cuda")
if device == "mps" and not torch.backends.mps.is_available(): if device == "mps" and not torch.backends.mps.is_available():
pytest.skip("mps is not available, skip test on mps") pytest.skip("mps is not available, skip test on mps")
steps = 1 if device == "cpu" else 20 steps = 2 if device == "cpu" else 20
return steps return steps