add back local_files_only to from_pretrained
This commit is contained in:
parent
8dd3a06945
commit
84c2b515c8
@ -359,6 +359,7 @@ class Api:
|
|||||||
low_mem=self.config.low_mem,
|
low_mem=self.config.low_mem,
|
||||||
disable_nsfw=self.config.disable_nsfw_checker,
|
disable_nsfw=self.config.disable_nsfw_checker,
|
||||||
sd_cpu_textencoder=self.config.cpu_textencoder,
|
sd_cpu_textencoder=self.config.cpu_textencoder,
|
||||||
|
local_files_only=self.config.local_files_only,
|
||||||
cpu_offload=self.config.cpu_offload,
|
cpu_offload=self.config.cpu_offload,
|
||||||
callback=diffuser_callback,
|
callback=diffuser_callback,
|
||||||
)
|
)
|
||||||
|
@ -176,6 +176,7 @@ def start(
|
|||||||
low_mem=low_mem,
|
low_mem=low_mem,
|
||||||
cpu_offload=cpu_offload,
|
cpu_offload=cpu_offload,
|
||||||
disable_nsfw_checker=disable_nsfw_checker,
|
disable_nsfw_checker=disable_nsfw_checker,
|
||||||
|
local_files_only=local_files_only,
|
||||||
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
|
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
|
||||||
device=device,
|
device=device,
|
||||||
gui=gui,
|
gui=gui,
|
||||||
|
@ -18,6 +18,7 @@ from .utils import (
|
|||||||
handle_from_pretrained_exceptions,
|
handle_from_pretrained_exceptions,
|
||||||
get_torch_dtype,
|
get_torch_dtype,
|
||||||
enable_low_mem,
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +48,12 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
self.model_info = model_info
|
self.model_info = model_info
|
||||||
self.controlnet_method = controlnet_method
|
self.controlnet_method = controlnet_method
|
||||||
|
|
||||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
model_kwargs = {
|
||||||
|
**kwargs.get("pipe_components", {}),
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
|
}
|
||||||
|
self.local_files_only = model_kwargs["local_files_only"]
|
||||||
|
|
||||||
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||||
"cpu_offload", False
|
"cpu_offload", False
|
||||||
)
|
)
|
||||||
@ -82,6 +88,7 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
controlnet = ControlNetModel.from_pretrained(
|
controlnet = ControlNetModel.from_pretrained(
|
||||||
pretrained_model_name_or_path=controlnet_method,
|
pretrained_model_name_or_path=controlnet_method,
|
||||||
resume_download=True,
|
resume_download=True,
|
||||||
|
local_files_only=model_kwargs["local_files_only"],
|
||||||
)
|
)
|
||||||
if model_info.is_single_file_diffusers:
|
if model_info.is_single_file_diffusers:
|
||||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
@ -124,7 +131,7 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
def switch_controlnet_method(self, new_method: str):
|
def switch_controlnet_method(self, new_method: str):
|
||||||
self.controlnet_method = new_method
|
self.controlnet_method = new_method
|
||||||
controlnet = ControlNetModel.from_pretrained(
|
controlnet = ControlNetModel.from_pretrained(
|
||||||
new_method, resume_download=True
|
new_method, resume_download=True, local_files_only=self.local_files_only
|
||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
self.model.controlnet = controlnet
|
self.model.controlnet = controlnet
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from loguru import logger
|
|||||||
from iopaint.const import INSTRUCT_PIX2PIX_NAME
|
from iopaint.const import INSTRUCT_PIX2PIX_NAME
|
||||||
from .base import DiffusionInpaintModel
|
from .base import DiffusionInpaintModel
|
||||||
from iopaint.schema import InpaintRequest
|
from iopaint.schema import InpaintRequest
|
||||||
from .utils import get_torch_dtype, enable_low_mem
|
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
|
||||||
|
|
||||||
|
|
||||||
class InstructPix2Pix(DiffusionInpaintModel):
|
class InstructPix2Pix(DiffusionInpaintModel):
|
||||||
@ -19,7 +19,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
|||||||
|
|
||||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
|
||||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(
|
model_kwargs.update(
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
from iopaint.const import KANDINSKY22_NAME
|
from iopaint.const import KANDINSKY22_NAME
|
||||||
from .base import DiffusionInpaintModel
|
from .base import DiffusionInpaintModel
|
||||||
from iopaint.schema import InpaintRequest
|
from iopaint.schema import InpaintRequest
|
||||||
from .utils import get_torch_dtype, enable_low_mem
|
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
|
||||||
|
|
||||||
|
|
||||||
class Kandinsky(DiffusionInpaintModel):
|
class Kandinsky(DiffusionInpaintModel):
|
||||||
@ -20,8 +20,8 @@ class Kandinsky(DiffusionInpaintModel):
|
|||||||
|
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
"torch_dtype": torch_dtype,
|
"torch_dtype": torch_dtype,
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.model = AutoPipelineForInpainting.from_pretrained(
|
self.model = AutoPipelineForInpainting.from_pretrained(
|
||||||
self.name, **model_kwargs
|
self.name, **model_kwargs
|
||||||
).to(device)
|
).to(device)
|
||||||
|
@ -7,7 +7,7 @@ from loguru import logger
|
|||||||
from iopaint.helper import decode_base64_to_image
|
from iopaint.helper import decode_base64_to_image
|
||||||
from .base import DiffusionInpaintModel
|
from .base import DiffusionInpaintModel
|
||||||
from iopaint.schema import InpaintRequest
|
from iopaint.schema import InpaintRequest
|
||||||
from .utils import get_torch_dtype, enable_low_mem
|
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
|
||||||
|
|
||||||
|
|
||||||
class PaintByExample(DiffusionInpaintModel):
|
class PaintByExample(DiffusionInpaintModel):
|
||||||
@ -19,7 +19,9 @@ class PaintByExample(DiffusionInpaintModel):
|
|||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
model_kwargs = {}
|
model_kwargs = {
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
|
}
|
||||||
|
|
||||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||||
logger.info("Disable Paint By Example Model NSFW checker")
|
logger.info("Disable Paint By Example Model NSFW checker")
|
||||||
|
@ -6,7 +6,12 @@ from loguru import logger
|
|||||||
|
|
||||||
from ..base import DiffusionInpaintModel
|
from ..base import DiffusionInpaintModel
|
||||||
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from ..utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
from ..utils import (
|
||||||
|
handle_from_pretrained_exceptions,
|
||||||
|
get_torch_dtype,
|
||||||
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
|
)
|
||||||
from iopaint.schema import InpaintRequest
|
from iopaint.schema import InpaintRequest
|
||||||
from .powerpaint_tokenizer import add_task_to_prompt
|
from .powerpaint_tokenizer import add_task_to_prompt
|
||||||
from ...const import POWERPAINT_NAME
|
from ...const import POWERPAINT_NAME
|
||||||
@ -23,7 +28,7 @@ class PowerPaint(DiffusionInpaintModel):
|
|||||||
from .powerpaint_tokenizer import PowerPaintTokenizer
|
from .powerpaint_tokenizer import PowerPaintTokenizer
|
||||||
|
|
||||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
model_kwargs = {}
|
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
|
||||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(
|
model_kwargs.update(
|
||||||
|
@ -5,7 +5,12 @@ from loguru import logger
|
|||||||
|
|
||||||
from .base import DiffusionInpaintModel
|
from .base import DiffusionInpaintModel
|
||||||
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
from .utils import (
|
||||||
|
handle_from_pretrained_exceptions,
|
||||||
|
get_torch_dtype,
|
||||||
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
|
)
|
||||||
from iopaint.schema import InpaintRequest, ModelType
|
from iopaint.schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
|
|
||||||
@ -19,8 +24,13 @@ class SD(DiffusionInpaintModel):
|
|||||||
|
|
||||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
|
|
||||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
model_kwargs = {
|
||||||
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False)
|
**kwargs.get("pipe_components", {}),
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
|
}
|
||||||
|
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||||
|
"cpu_offload", False
|
||||||
|
)
|
||||||
if disable_nsfw_checker:
|
if disable_nsfw_checker:
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(
|
model_kwargs.update(
|
||||||
|
@ -10,7 +10,12 @@ from iopaint.schema import InpaintRequest, ModelType
|
|||||||
|
|
||||||
from .base import DiffusionInpaintModel
|
from .base import DiffusionInpaintModel
|
||||||
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
from .utils import (
|
||||||
|
handle_from_pretrained_exceptions,
|
||||||
|
get_torch_dtype,
|
||||||
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXL(DiffusionInpaintModel):
|
class SDXL(DiffusionInpaintModel):
|
||||||
@ -35,10 +40,13 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
self.model_id_or_path,
|
self.model_id_or_path,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
num_in_channels=num_in_channels,
|
num_in_channels=num_in_channels,
|
||||||
load_safety_checker=False
|
load_safety_checker=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
model_kwargs = {
|
||||||
|
**kwargs.get("pipe_components", {}),
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
|
}
|
||||||
if "vae" not in model_kwargs:
|
if "vae" not in model_kwargs:
|
||||||
vae = AutoencoderKL.from_pretrained(
|
vae = AutoencoderKL.from_pretrained(
|
||||||
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
|
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
|
||||||
|
@ -971,6 +971,12 @@ def get_scheduler(sd_sampler, scheduler_config):
|
|||||||
raise ValueError(sd_sampler)
|
raise ValueError(sd_sampler)
|
||||||
|
|
||||||
|
|
||||||
|
def is_local_files_only(**kwargs) -> bool:
|
||||||
|
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||||
|
|
||||||
|
return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)
|
||||||
|
|
||||||
|
|
||||||
def handle_from_pretrained_exceptions(func, **kwargs):
|
def handle_from_pretrained_exceptions(func, **kwargs):
|
||||||
try:
|
try:
|
||||||
return func(**kwargs)
|
return func(**kwargs)
|
||||||
|
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
from iopaint.download import scan_models
|
from iopaint.download import scan_models
|
||||||
from iopaint.helper import switch_mps_device
|
from iopaint.helper import switch_mps_device
|
||||||
from iopaint.model import models, ControlNet, SD, SDXL
|
from iopaint.model import models, ControlNet, SD, SDXL
|
||||||
from iopaint.model.utils import torch_gc
|
from iopaint.model.utils import torch_gc, is_local_files_only
|
||||||
from iopaint.model_info import ModelInfo, ModelType
|
from iopaint.model_info import ModelInfo, ModelType
|
||||||
from iopaint.schema import InpaintRequest
|
from iopaint.schema import InpaintRequest
|
||||||
|
|
||||||
@ -182,7 +182,11 @@ class ModelManager:
|
|||||||
lcm_lora_loaded = bool(self.model.model.get_list_adapters())
|
lcm_lora_loaded = bool(self.model.model.get_list_adapters())
|
||||||
if config.sd_lcm_lora:
|
if config.sd_lcm_lora:
|
||||||
if not lcm_lora_loaded:
|
if not lcm_lora_loaded:
|
||||||
self.model.model.load_lora_weights(self.model.lcm_lora_id, weight_name="pytorch_lora_weights.safetensors")
|
self.model.model.load_lora_weights(
|
||||||
|
self.model.lcm_lora_id,
|
||||||
|
weight_name="pytorch_lora_weights.safetensors",
|
||||||
|
local_files_only=is_local_files_only(),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if lcm_lora_loaded:
|
if lcm_lora_loaded:
|
||||||
self.model.model.disable_lora()
|
self.model.model.disable_lora()
|
||||||
|
@ -90,6 +90,7 @@ class ApiConfig(BaseModel):
|
|||||||
low_mem: bool
|
low_mem: bool
|
||||||
cpu_offload: bool
|
cpu_offload: bool
|
||||||
disable_nsfw_checker: bool
|
disable_nsfw_checker: bool
|
||||||
|
local_files_only: bool
|
||||||
cpu_textencoder: bool
|
cpu_textencoder: bool
|
||||||
device: Device
|
device: Device
|
||||||
gui: bool
|
gui: bool
|
||||||
|
Loading…
Reference in New Issue
Block a user