From a71c3fbe1b37173366366e7bacf953815e4d7b63 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 8 Jan 2024 23:53:20 +0800 Subject: [PATCH] clean code: get_torch_dtype; mps use float16 by default --- iopaint/model/controlnet.py | 6 ++---- iopaint/model/instruct_pix2pix.py | 5 ++--- iopaint/model/kandinsky.py | 5 ++--- iopaint/model/paint_by_example.py | 5 ++--- iopaint/model/power_paint/power_paint.py | 7 ++----- iopaint/model/sd.py | 6 ++---- iopaint/model/sdxl.py | 7 ++----- iopaint/model/utils.py | 10 +++++++++- 8 files changed, 23 insertions(+), 28 deletions(-) diff --git a/iopaint/model/controlnet.py b/iopaint/model/controlnet.py index 45836de..ffaef1d 100644 --- a/iopaint/model/controlnet.py +++ b/iopaint/model/controlnet.py @@ -13,7 +13,7 @@ from .helper.controlnet_preprocess import ( make_inpaint_control_image, ) from .helper.cpu_text_encoder import CPUTextEncoderWrapper -from .utils import get_scheduler, handle_from_pretrained_exceptions +from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype class ControlNet(DiffusionInpaintModel): @@ -36,7 +36,6 @@ class ControlNet(DiffusionInpaintModel): raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}") def init_model(self, device: torch.device, **kwargs): - fp16 = not kwargs.get("no_half", False) model_info = kwargs["model_info"] controlnet_method = kwargs["controlnet_method"] @@ -54,8 +53,7 @@ class ControlNet(DiffusionInpaintModel): ) ) - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) self.torch_dtype = torch_dtype if model_info.model_type in [ diff --git a/iopaint/model/instruct_pix2pix.py b/iopaint/model/instruct_pix2pix.py index ec69927..a792f0a 100644 --- a/iopaint/model/instruct_pix2pix.py +++ b/iopaint/model/instruct_pix2pix.py @@ -6,6 +6,7 @@ from loguru import logger from iopaint.const import INSTRUCT_PIX2PIX_NAME from .base import DiffusionInpaintModel from iopaint.schema import InpaintRequest +from .utils import get_torch_dtype class InstructPix2Pix(DiffusionInpaintModel): @@ -16,7 +17,7 @@ class InstructPix2Pix(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): from diffusers import StableDiffusionInstructPix2PixPipeline - fp16 = not kwargs.get("no_half", False) + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): @@ -29,8 +30,6 @@ class InstructPix2Pix(DiffusionInpaintModel): ) ) - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained( self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs ) diff --git a/iopaint/model/kandinsky.py b/iopaint/model/kandinsky.py index 9e66b0b..fe67891 100644 --- a/iopaint/model/kandinsky.py +++ b/iopaint/model/kandinsky.py @@ -6,6 +6,7 @@ import torch from iopaint.const import KANDINSKY22_NAME from .base import DiffusionInpaintModel from iopaint.schema import InpaintRequest +from .utils import get_torch_dtype class Kandinsky(DiffusionInpaintModel): @@ -15,9 +16,7 @@ class Kandinsky(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): from diffusers import AutoPipelineForInpainting - fp16 = not kwargs.get("no_half", False) - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) model_kwargs = { "torch_dtype": torch_dtype, diff --git a/iopaint/model/paint_by_example.py b/iopaint/model/paint_by_example.py index 8e0abee..2ae4cad 100644 --- a/iopaint/model/paint_by_example.py +++ b/iopaint/model/paint_by_example.py @@ -7,6 +7,7 @@ from loguru import logger from iopaint.helper import decode_base64_to_image from .base import DiffusionInpaintModel from iopaint.schema import InpaintRequest +from .utils import get_torch_dtype class PaintByExample(DiffusionInpaintModel): @@ -17,9 +18,7 @@ class PaintByExample(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): from diffusers import DiffusionPipeline - fp16 = not kwargs.get("no_half", False) - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): diff --git a/iopaint/model/power_paint/power_paint.py b/iopaint/model/power_paint/power_paint.py index 6c74dd9..616bb35 100644 --- a/iopaint/model/power_paint/power_paint.py +++ b/iopaint/model/power_paint/power_paint.py @@ -6,7 +6,7 @@ from loguru import logger from ..base import DiffusionInpaintModel from ..helper.cpu_text_encoder import CPUTextEncoderWrapper -from ..utils import handle_from_pretrained_exceptions +from ..utils import handle_from_pretrained_exceptions, get_torch_dtype from iopaint.schema import InpaintRequest from .powerpaint_tokenizer import add_task_to_prompt from ...const import POWERPAINT_NAME @@ -22,7 +22,7 @@ class PowerPaint(DiffusionInpaintModel): from .pipeline_powerpaint import StableDiffusionInpaintPipeline from .powerpaint_tokenizer import PowerPaintTokenizer - fp16 = not kwargs.get("no_half", False) + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): logger.info("Disable Stable Diffusion Model NSFW checker") @@ -34,9 +34,6 @@ class PowerPaint(DiffusionInpaintModel): ) ) - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 - self.model = handle_from_pretrained_exceptions( StableDiffusionInpaintPipeline.from_pretrained, pretrained_model_name_or_path=self.name, diff --git a/iopaint/model/sd.py b/iopaint/model/sd.py index 6a11b32..95189f4 100644 --- a/iopaint/model/sd.py +++ b/iopaint/model/sd.py @@ -5,7 +5,7 @@ from loguru import logger from .base import DiffusionInpaintModel from .helper.cpu_text_encoder import CPUTextEncoderWrapper -from .utils import handle_from_pretrained_exceptions +from .utils import handle_from_pretrained_exceptions, get_torch_dtype from iopaint.schema import InpaintRequest, ModelType @@ -17,7 +17,7 @@ class SD(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline - fp16 = not kwargs.get("no_half", False) + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) model_kwargs = {} if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): @@ -29,8 +29,6 @@ class SD(DiffusionInpaintModel): requires_safety_checker=False, ) ) - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 if self.model_info.is_single_file_diffusers: if self.model_info.model_type == ModelType.DIFFUSERS_SD: diff --git a/iopaint/model/sdxl.py b/iopaint/model/sdxl.py index 6778be7..cf1dab3 100644 --- a/iopaint/model/sdxl.py +++ b/iopaint/model/sdxl.py @@ -9,7 +9,7 @@ from loguru import logger from iopaint.schema import InpaintRequest, ModelType from .base import DiffusionInpaintModel -from .utils import handle_from_pretrained_exceptions +from .utils import handle_from_pretrained_exceptions, get_torch_dtype class SDXL(DiffusionInpaintModel): @@ -22,10 +22,7 @@ class SDXL(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): from diffusers.pipelines import StableDiffusionXLInpaintPipeline - fp16 = not kwargs.get("no_half", False) - - use_gpu = device == torch.device("cuda") and torch.cuda.is_available() - torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) if self.model_info.model_type == ModelType.DIFFUSERS_SDXL: num_in_channels = 4 diff --git a/iopaint/model/utils.py b/iopaint/model/utils.py index 1292390..80e88b3 100644 --- a/iopaint/model/utils.py +++ b/iopaint/model/utils.py @@ -1,4 +1,3 @@ -import copy import gc import math import random @@ -994,3 +993,12 @@ def handle_from_pretrained_exceptions(func, **kwargs): raise e except Exception as e: raise e + + +def get_torch_dtype(device, no_half: bool): + device = str(device) + use_fp16 = not no_half + use_gpu = device == "cuda" + if device in ["cuda", "mps"] and use_fp16: + return use_gpu, torch.float16 + return use_gpu, torch.float32