add sd-disable-nsfw arg

This commit is contained in:
Qing 2022-09-29 09:42:19 +08:00
parent 1a92569f00
commit 0d57e552cf
3 changed files with 41 additions and 1 deletions

View File

@ -6,6 +6,7 @@ import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler
from loguru import logger
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
from lama_cleaner.helper import norm_img
@ -38,6 +39,29 @@ from lama_cleaner.schema import Config, SDSampler
# mask = torch.from_numpy(mask)
# return mask
class DummyFeatureExtractorOutput:
def __init__(self, pixel_values):
self.pixel_values = pixel_values
def to(self, device):
return self
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, *args, **kwargs):
return DummyFeatureExtractorOutput(torch.empty(0, 3))
class DummySafetyChecker:
def __init__(self, *args, **kwargs):
pass
def __call__(self, clip_input, images):
return images, False
class SD(InpaintModel):
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
@ -46,11 +70,21 @@ class SD(InpaintModel):
def init_model(self, device: torch.device, **kwargs):
from .sd_pipeline import StableDiffusionInpaintPipeline
model_kwargs = {}
sd_disable_nsfw = kwargs.pop('sd_disable_nsfw', False)
if sd_disable_nsfw:
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
feature_extractor=DummyFeatureExtractor(),
safety_checker=DummySafetyChecker(),
))
self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path,
revision="fp16" if torch.cuda.is_available() else "main",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_auth_token=kwargs["hf_access_token"],
**model_kwargs
)
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()

View File

@ -17,6 +17,11 @@ def parse_args():
default="",
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
)
parser.add_argument(
"--sd-disable-nsfw",
action="store_true",
help="disable stable diffusion nsfw checker",
)
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(

View File

@ -218,6 +218,7 @@ def main(args):
name=args.model,
device=device,
hf_access_token=args.hf_access_token,
sd_disable_nsfw=args.sd_disable_nsfw,
callbacks=[diffuser_callback],
)