sdxl support cpu_text_encoder
This commit is contained in:
parent
05a15b2e1f
commit
38b6edacf0
@ -8,6 +8,7 @@ class CPUTextEncoderWrapper(PreTrainedModel):
|
|||||||
def __init__(self, text_encoder, torch_dtype):
|
def __init__(self, text_encoder, torch_dtype):
|
||||||
super().__init__(text_encoder.config)
|
super().__init__(text_encoder.config)
|
||||||
self.config = text_encoder.config
|
self.config = text_encoder.config
|
||||||
|
# cpu not support float16
|
||||||
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
||||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||||
self.torch_dtype = torch_dtype
|
self.torch_dtype = torch_dtype
|
||||||
@ -16,11 +17,15 @@ class CPUTextEncoderWrapper(PreTrainedModel):
|
|||||||
|
|
||||||
def __call__(self, x, **kwargs):
|
def __call__(self, x, **kwargs):
|
||||||
input_device = x.device
|
input_device = x.device
|
||||||
return [
|
original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
|
||||||
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
|
for k, v in original_output.items():
|
||||||
.to(input_device)
|
if isinstance(v, tuple):
|
||||||
.to(self.torch_dtype)
|
original_output[k] = [
|
||||||
]
|
v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
original_output[k] = v.to(input_device).to(self.torch_dtype)
|
||||||
|
return original_output
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
|
@ -9,6 +9,7 @@ from loguru import logger
|
|||||||
from iopaint.schema import InpaintRequest, ModelType
|
from iopaint.schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
from .base import DiffusionInpaintModel
|
from .base import DiffusionInpaintModel
|
||||||
|
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
||||||
|
|
||||||
|
|
||||||
@ -37,11 +38,11 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
model_kwargs = {**kwargs.get("pipe_components", {})}
|
||||||
if 'vae' not in model_kwargs:
|
if "vae" not in model_kwargs:
|
||||||
vae = AutoencoderKL.from_pretrained(
|
vae = AutoencoderKL.from_pretrained(
|
||||||
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
|
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
|
||||||
)
|
)
|
||||||
model_kwargs['vae'] = vae
|
model_kwargs["vae"] = vae
|
||||||
self.model = handle_from_pretrained_exceptions(
|
self.model = handle_from_pretrained_exceptions(
|
||||||
StableDiffusionXLInpaintPipeline.from_pretrained,
|
StableDiffusionXLInpaintPipeline.from_pretrained,
|
||||||
pretrained_model_name_or_path=self.model_id_or_path,
|
pretrained_model_name_or_path=self.model_id_or_path,
|
||||||
@ -58,7 +59,13 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
else:
|
else:
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
if kwargs["sd_cpu_textencoder"]:
|
if kwargs["sd_cpu_textencoder"]:
|
||||||
logger.warning("Stable Diffusion XL not support run TextEncoder on CPU")
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
|
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||||
|
self.model.text_encoder, torch_dtype
|
||||||
|
)
|
||||||
|
self.model.text_encoder_2 = CPUTextEncoderWrapper(
|
||||||
|
self.model.text_encoder_2, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
|
@ -147,6 +147,33 @@ def test_runway_sd_sd_strength(device, strategy, sampler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
|
def test_runway_sd_cpu_textencoder(device, strategy, sampler):
|
||||||
|
sd_steps = check_device(device)
|
||||||
|
model = ModelManager(
|
||||||
|
name="runwayml/stable-diffusion-inpainting",
|
||||||
|
device=torch.device(device),
|
||||||
|
disable_nsfw=True,
|
||||||
|
sd_cpu_textencoder=True,
|
||||||
|
)
|
||||||
|
cfg = get_config(
|
||||||
|
strategy=strategy,
|
||||||
|
prompt="a fox sitting on a bench",
|
||||||
|
sd_steps=sd_steps,
|
||||||
|
sd_sampler=sampler,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"runway_sd_device_{device}_cpu_textencoder.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
|
@ -44,6 +44,38 @@ def test_sdxl(device, strategy, sampler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
|
def test_sdxl_cpu_text_encoder(device, strategy, sampler):
|
||||||
|
sd_steps = check_device(device)
|
||||||
|
|
||||||
|
model = ModelManager(
|
||||||
|
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||||
|
device=torch.device(device),
|
||||||
|
disable_nsfw=True,
|
||||||
|
sd_cpu_textencoder=True,
|
||||||
|
)
|
||||||
|
cfg = get_config(
|
||||||
|
strategy=strategy,
|
||||||
|
prompt="face of a fox, sitting on a bench",
|
||||||
|
sd_steps=sd_steps,
|
||||||
|
sd_strength=1.0,
|
||||||
|
sd_guidance_scale=7.0,
|
||||||
|
)
|
||||||
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"sdxl_device_{device}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fx=2,
|
||||||
|
fy=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
|
@ -17,7 +17,7 @@ def check_device(device: str) -> int:
|
|||||||
pytest.skip("CUDA is not available, skip test on cuda")
|
pytest.skip("CUDA is not available, skip test on cuda")
|
||||||
if device == "mps" and not torch.backends.mps.is_available():
|
if device == "mps" and not torch.backends.mps.is_available():
|
||||||
pytest.skip("mps is not available, skip test on mps")
|
pytest.skip("mps is not available, skip test on mps")
|
||||||
steps = 1 if device == "cpu" else 20
|
steps = 2 if device == "cpu" else 20
|
||||||
return steps
|
return steps
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user