add sd-cpu-textencoder args

This commit is contained in:
Qing 2022-09-29 12:20:55 +08:00
parent 0d57e552cf
commit ec7b2d8e2d
3 changed files with 14 additions and 4 deletions

View File

@ -71,8 +71,7 @@ class SD(InpaintModel):
from .sd_pipeline import StableDiffusionInpaintPipeline from .sd_pipeline import StableDiffusionInpaintPipeline
model_kwargs = {} model_kwargs = {}
sd_disable_nsfw = kwargs.pop('sd_disable_nsfw', False) if kwargs['sd_disable_nsfw']:
if sd_disable_nsfw:
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict( model_kwargs.update(dict(
feature_extractor=DummyFeatureExtractor(), feature_extractor=DummyFeatureExtractor(),
@ -89,6 +88,11 @@ class SD(InpaintModel):
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing() self.model.enable_attention_slicing()
self.model = self.model.to(device) self.model = self.model.to(device)
if kwargs['sd_cpu_textencoder']:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'))
self.callbacks = kwargs.pop("callbacks", None) self.callbacks = kwargs.pop("callbacks", None)
@torch.cuda.amp.autocast() @torch.cuda.amp.autocast()

View File

@ -15,12 +15,17 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--hf_access_token", "--hf_access_token",
default="", default="",
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens", help="Huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
) )
parser.add_argument( parser.add_argument(
"--sd-disable-nsfw", "--sd-disable-nsfw",
action="store_true", action="store_true",
help="disable stable diffusion nsfw checker", help="Disable Stable Diffusion nsfw checker",
)
parser.add_argument(
"--sd-cpu-textencoder",
action="store_true",
help="Always run Stable Diffusion TextEncoder model on CPU",
) )
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"]) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument("--gui", action="store_true", help="Launch as desktop app")

View File

@ -219,6 +219,7 @@ def main(args):
device=device, device=device,
hf_access_token=args.hf_access_token, hf_access_token=args.hf_access_token,
sd_disable_nsfw=args.sd_disable_nsfw, sd_disable_nsfw=args.sd_disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder,
callbacks=[diffuser_callback], callbacks=[diffuser_callback],
) )