add back local_files_only to from_pretrained
commit 84c2b515c8
parent 8dd3a06945
@@ -359,6 +359,7 @@ class Api:
             low_mem=self.config.low_mem,
             disable_nsfw=self.config.disable_nsfw_checker,
             sd_cpu_textencoder=self.config.cpu_textencoder,
+            local_files_only=self.config.local_files_only,
             cpu_offload=self.config.cpu_offload,
             callback=diffuser_callback,
         )
@@ -176,6 +176,7 @@ def start(
         low_mem=low_mem,
         cpu_offload=cpu_offload,
         disable_nsfw_checker=disable_nsfw_checker,
+        local_files_only=local_files_only,
         cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
         device=device,
         gui=gui,
@@ -18,6 +18,7 @@ from .utils import (
     handle_from_pretrained_exceptions,
     get_torch_dtype,
     enable_low_mem,
+    is_local_files_only,
 )


@@ -47,7 +48,12 @@ class ControlNet(DiffusionInpaintModel):
         self.model_info = model_info
         self.controlnet_method = controlnet_method

-        model_kwargs = {**kwargs.get("pipe_components", {})}
+        model_kwargs = {
+            **kwargs.get("pipe_components", {}),
+            "local_files_only": is_local_files_only(**kwargs),
+        }
+        self.local_files_only = model_kwargs["local_files_only"]
+
         disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
             "cpu_offload", False
         )
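For context, a minimal sketch of what the forwarded flag does once model_kwargs reaches a diffusers from_pretrained call. This is an illustration, not code from the commit; the ControlNet repo id is a placeholder.

# Illustration only: with local_files_only=True, diffusers resolves the weights
# from the local Hugging Face cache and raises instead of downloading if missing.
from diffusers import ControlNetModel

model_kwargs = {"local_files_only": True}  # as produced by is_local_files_only(**kwargs)
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint",  # placeholder ControlNet repo id
    **model_kwargs,
)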
@@ -82,6 +88,7 @@ class ControlNet(DiffusionInpaintModel):
         controlnet = ControlNetModel.from_pretrained(
             pretrained_model_name_or_path=controlnet_method,
             resume_download=True,
+            local_files_only=model_kwargs["local_files_only"],
         )
         if model_info.is_single_file_diffusers:
             if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@@ -124,7 +131,7 @@ class ControlNet(DiffusionInpaintModel):
     def switch_controlnet_method(self, new_method: str):
         self.controlnet_method = new_method
         controlnet = ControlNetModel.from_pretrained(
-            new_method, resume_download=True
+            new_method, resume_download=True, local_files_only=self.local_files_only
         ).to(self.model.device)
         self.model.controlnet = controlnet

@@ -6,7 +6,7 @@ from loguru import logger
 from iopaint.const import INSTRUCT_PIX2PIX_NAME
 from .base import DiffusionInpaintModel
 from iopaint.schema import InpaintRequest
-from .utils import get_torch_dtype, enable_low_mem
+from .utils import get_torch_dtype, enable_low_mem, is_local_files_only


 class InstructPix2Pix(DiffusionInpaintModel):
@@ -19,7 +19,7 @@ class InstructPix2Pix(DiffusionInpaintModel):

         use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))

-        model_kwargs = {}
+        model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
             logger.info("Disable Stable Diffusion Model NSFW checker")
             model_kwargs.update(
@@ -6,7 +6,7 @@ import torch
 from iopaint.const import KANDINSKY22_NAME
 from .base import DiffusionInpaintModel
 from iopaint.schema import InpaintRequest
-from .utils import get_torch_dtype, enable_low_mem
+from .utils import get_torch_dtype, enable_low_mem, is_local_files_only


 class Kandinsky(DiffusionInpaintModel):
@@ -20,8 +20,8 @@ class Kandinsky(DiffusionInpaintModel):

         model_kwargs = {
             "torch_dtype": torch_dtype,
+            "local_files_only": is_local_files_only(**kwargs),
         }
-
         self.model = AutoPipelineForInpainting.from_pretrained(
             self.name, **model_kwargs
         ).to(device)
@@ -7,7 +7,7 @@ from loguru import logger
 from iopaint.helper import decode_base64_to_image
 from .base import DiffusionInpaintModel
 from iopaint.schema import InpaintRequest
-from .utils import get_torch_dtype, enable_low_mem
+from .utils import get_torch_dtype, enable_low_mem, is_local_files_only


 class PaintByExample(DiffusionInpaintModel):
@@ -19,7 +19,9 @@ class PaintByExample(DiffusionInpaintModel):
         from diffusers import DiffusionPipeline

         use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
-        model_kwargs = {}
+        model_kwargs = {
+            "local_files_only": is_local_files_only(**kwargs),
+        }

         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
             logger.info("Disable Paint By Example Model NSFW checker")
@@ -6,7 +6,12 @@ from loguru import logger

 from ..base import DiffusionInpaintModel
 from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
-from ..utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
+from ..utils import (
+    handle_from_pretrained_exceptions,
+    get_torch_dtype,
+    enable_low_mem,
+    is_local_files_only,
+)
 from iopaint.schema import InpaintRequest
 from .powerpaint_tokenizer import add_task_to_prompt
 from ...const import POWERPAINT_NAME
@@ -23,7 +28,7 @@ class PowerPaint(DiffusionInpaintModel):
         from .powerpaint_tokenizer import PowerPaintTokenizer

         use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
-        model_kwargs = {}
+        model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
             logger.info("Disable Stable Diffusion Model NSFW checker")
             model_kwargs.update(
@@ -5,7 +5,12 @@ from loguru import logger

 from .base import DiffusionInpaintModel
 from .helper.cpu_text_encoder import CPUTextEncoderWrapper
-from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
+from .utils import (
+    handle_from_pretrained_exceptions,
+    get_torch_dtype,
+    enable_low_mem,
+    is_local_files_only,
+)
 from iopaint.schema import InpaintRequest, ModelType


@@ -19,8 +24,13 @@ class SD(DiffusionInpaintModel):

         use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))

-        model_kwargs = {**kwargs.get("pipe_components", {})}
-        disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False)
+        model_kwargs = {
+            **kwargs.get("pipe_components", {}),
+            "local_files_only": is_local_files_only(**kwargs),
+        }
+        disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
+            "cpu_offload", False
+        )
         if disable_nsfw_checker:
             logger.info("Disable Stable Diffusion Model NSFW checker")
             model_kwargs.update(
@@ -10,7 +10,12 @@ from iopaint.schema import InpaintRequest, ModelType

 from .base import DiffusionInpaintModel
 from .helper.cpu_text_encoder import CPUTextEncoderWrapper
-from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
+from .utils import (
+    handle_from_pretrained_exceptions,
+    get_torch_dtype,
+    enable_low_mem,
+    is_local_files_only,
+)


 class SDXL(DiffusionInpaintModel):
@@ -35,10 +40,13 @@ class SDXL(DiffusionInpaintModel):
                 self.model_id_or_path,
                 dtype=torch_dtype,
                 num_in_channels=num_in_channels,
-                load_safety_checker=False
+                load_safety_checker=False,
             )
         else:
-            model_kwargs = {**kwargs.get("pipe_components", {})}
+            model_kwargs = {
+                **kwargs.get("pipe_components", {}),
+                "local_files_only": is_local_files_only(**kwargs),
+            }
             if "vae" not in model_kwargs:
                 vae = AutoencoderKL.from_pretrained(
                     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
@@ -971,6 +971,12 @@ def get_scheduler(sd_sampler, scheduler_config):
         raise ValueError(sd_sampler)


+def is_local_files_only(**kwargs) -> bool:
+    from huggingface_hub.constants import HF_HUB_OFFLINE
+
+    return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)
+
+
 def handle_from_pretrained_exceptions(func, **kwargs):
     try:
         return func(**kwargs)
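A minimal standalone sketch of how the new helper resolves the flag, assuming huggingface_hub is installed. The HF_HUB_OFFLINE constant is parsed from the environment when the module is imported, so the env var must be set before the import.

# Sketch only: either the Hugging Face offline env var or an explicit
# local_files_only kwarg forces local-only loading.
import os

os.environ["HF_HUB_OFFLINE"] = "1"  # simulate offline mode

from huggingface_hub.constants import HF_HUB_OFFLINE


def is_local_files_only(**kwargs) -> bool:
    return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)


print(is_local_files_only())                        # True: offline env var is set
print(is_local_files_only(local_files_only=False))  # still True: env var takes priority
print(is_local_files_only(local_files_only=True))   # True either way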
@@ -7,7 +7,7 @@ import numpy as np
 from iopaint.download import scan_models
 from iopaint.helper import switch_mps_device
 from iopaint.model import models, ControlNet, SD, SDXL
-from iopaint.model.utils import torch_gc
+from iopaint.model.utils import torch_gc, is_local_files_only
 from iopaint.model_info import ModelInfo, ModelType
 from iopaint.schema import InpaintRequest

@@ -182,7 +182,11 @@ class ModelManager:
             lcm_lora_loaded = bool(self.model.model.get_list_adapters())
             if config.sd_lcm_lora:
                 if not lcm_lora_loaded:
-                    self.model.model.load_lora_weights(self.model.lcm_lora_id, weight_name="pytorch_lora_weights.safetensors")
+                    self.model.model.load_lora_weights(
+                        self.model.lcm_lora_id,
+                        weight_name="pytorch_lora_weights.safetensors",
+                        local_files_only=is_local_files_only(),
+                    )
             else:
                 if lcm_lora_loaded:
                     self.model.model.disable_lora()
@@ -90,6 +90,7 @@ class ApiConfig(BaseModel):
     low_mem: bool
     cpu_offload: bool
    disable_nsfw_checker: bool
+    local_files_only: bool
     cpu_textencoder: bool
     device: Device
     gui: bool