add sd-disable-nsfw arg
This commit is contained in:
parent
1a92569f00
commit
0d57e552cf
@ -6,6 +6,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from diffusers import PNDMScheduler, DDIMScheduler
|
from diffusers import PNDMScheduler, DDIMScheduler
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
|
||||||
|
|
||||||
from lama_cleaner.helper import norm_img
|
from lama_cleaner.helper import norm_img
|
||||||
|
|
||||||
@ -38,19 +39,52 @@ from lama_cleaner.schema import Config, SDSampler
|
|||||||
# mask = torch.from_numpy(mask)
|
# mask = torch.from_numpy(mask)
|
||||||
# return mask
|
# return mask
|
||||||
|
|
||||||
|
class DummyFeatureExtractorOutput:
|
||||||
|
def __init__(self, pixel_values):
|
||||||
|
self.pixel_values = pixel_values
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return DummyFeatureExtractorOutput(torch.empty(0, 3))
|
||||||
|
|
||||||
|
|
||||||
|
class DummySafetyChecker:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, clip_input, images):
|
||||||
|
return images, False
|
||||||
|
|
||||||
|
|
||||||
class SD(InpaintModel):
|
class SD(InpaintModel):
|
||||||
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
|
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
|
||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
def init_model(self, device: torch.device, **kwargs):
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
from .sd_pipeline import StableDiffusionInpaintPipeline
|
from .sd_pipeline import StableDiffusionInpaintPipeline
|
||||||
|
|
||||||
|
model_kwargs = {}
|
||||||
|
sd_disable_nsfw = kwargs.pop('sd_disable_nsfw', False)
|
||||||
|
if sd_disable_nsfw:
|
||||||
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
|
model_kwargs.update(dict(
|
||||||
|
feature_extractor=DummyFeatureExtractor(),
|
||||||
|
safety_checker=DummySafetyChecker(),
|
||||||
|
))
|
||||||
|
|
||||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
self.model_id_or_path,
|
self.model_id_or_path,
|
||||||
revision="fp16" if torch.cuda.is_available() else "main",
|
revision="fp16" if torch.cuda.is_available() else "main",
|
||||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||||
use_auth_token=kwargs["hf_access_token"],
|
use_auth_token=kwargs["hf_access_token"],
|
||||||
|
**model_kwargs
|
||||||
)
|
)
|
||||||
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||||
self.model.enable_attention_slicing()
|
self.model.enable_attention_slicing()
|
||||||
|
@ -17,6 +17,11 @@ def parse_args():
|
|||||||
default="",
|
default="",
|
||||||
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
|
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sd-disable-nsfw",
|
||||||
|
action="store_true",
|
||||||
|
help="disable stable diffusion nsfw checker",
|
||||||
|
)
|
||||||
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
|
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
|
||||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -218,6 +218,7 @@ def main(args):
|
|||||||
name=args.model,
|
name=args.model,
|
||||||
device=device,
|
device=device,
|
||||||
hf_access_token=args.hf_access_token,
|
hf_access_token=args.hf_access_token,
|
||||||
|
sd_disable_nsfw=args.sd_disable_nsfw,
|
||||||
callbacks=[diffuser_callback],
|
callbacks=[diffuser_callback],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user