add sd-disable-nsfw arg

This commit is contained in:
Qing 2022-09-29 09:42:19 +08:00
parent 1a92569f00
commit 0d57e552cf
3 changed files with 41 additions and 1 deletions

View File

@ -6,6 +6,7 @@ import numpy as np
import torch import torch
from diffusers import PNDMScheduler, DDIMScheduler from diffusers import PNDMScheduler, DDIMScheduler
from loguru import logger from loguru import logger
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
from lama_cleaner.helper import norm_img from lama_cleaner.helper import norm_img
@ -38,19 +39,52 @@ from lama_cleaner.schema import Config, SDSampler
# mask = torch.from_numpy(mask) # mask = torch.from_numpy(mask)
# return mask # return mask
class DummyFeatureExtractorOutput:
def __init__(self, pixel_values):
self.pixel_values = pixel_values
def to(self, device):
return self
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, *args, **kwargs):
return DummyFeatureExtractorOutput(torch.empty(0, 3))
class DummySafetyChecker:
def __init__(self, *args, **kwargs):
pass
def __call__(self, clip_input, images):
return images, False
class SD(InpaintModel): class SD(InpaintModel):
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505 pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
min_size = 512 min_size = 512
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from .sd_pipeline import StableDiffusionInpaintPipeline from .sd_pipeline import StableDiffusionInpaintPipeline
model_kwargs = {}
sd_disable_nsfw = kwargs.pop('sd_disable_nsfw', False)
if sd_disable_nsfw:
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
feature_extractor=DummyFeatureExtractor(),
safety_checker=DummySafetyChecker(),
))
self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path, self.model_id_or_path,
revision="fp16" if torch.cuda.is_available() else "main", revision="fp16" if torch.cuda.is_available() else "main",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_auth_token=kwargs["hf_access_token"], use_auth_token=kwargs["hf_access_token"],
**model_kwargs
) )
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing() self.model.enable_attention_slicing()

View File

@ -17,6 +17,11 @@ def parse_args():
default="", default="",
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens", help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
) )
parser.add_argument(
"--sd-disable-nsfw",
action="store_true",
help="disable stable diffusion nsfw checker",
)
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"]) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument( parser.add_argument(

View File

@ -218,6 +218,7 @@ def main(args):
name=args.model, name=args.model,
device=device, device=device,
hf_access_token=args.hf_access_token, hf_access_token=args.hf_access_token,
sd_disable_nsfw=args.sd_disable_nsfw,
callbacks=[diffuser_callback], callbacks=[diffuser_callback],
) )