add sd-cpu-textencoder args
This commit is contained in:
parent
0d57e552cf
commit
ec7b2d8e2d
@ -71,8 +71,7 @@ class SD(InpaintModel):
|
|||||||
from .sd_pipeline import StableDiffusionInpaintPipeline
|
from .sd_pipeline import StableDiffusionInpaintPipeline
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
sd_disable_nsfw = kwargs.pop('sd_disable_nsfw', False)
|
if kwargs['sd_disable_nsfw']:
|
||||||
if sd_disable_nsfw:
|
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(dict(
|
model_kwargs.update(dict(
|
||||||
feature_extractor=DummyFeatureExtractor(),
|
feature_extractor=DummyFeatureExtractor(),
|
||||||
@ -89,6 +88,11 @@ class SD(InpaintModel):
|
|||||||
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||||
self.model.enable_attention_slicing()
|
self.model.enable_attention_slicing()
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
if kwargs['sd_cpu_textencoder']:
|
||||||
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
|
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'))
|
||||||
|
|
||||||
self.callbacks = kwargs.pop("callbacks", None)
|
self.callbacks = kwargs.pop("callbacks", None)
|
||||||
|
|
||||||
@torch.cuda.amp.autocast()
|
@torch.cuda.amp.autocast()
|
||||||
|
@ -15,12 +15,17 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hf_access_token",
|
"--hf_access_token",
|
||||||
default="",
|
default="",
|
||||||
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
|
help="Huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sd-disable-nsfw",
|
"--sd-disable-nsfw",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="disable stable diffusion nsfw checker",
|
help="Disable Stable Diffusion nsfw checker",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sd-cpu-textencoder",
|
||||||
|
action="store_true",
|
||||||
|
help="Always run Stable Diffusion TextEncoder model on CPU",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
|
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
|
||||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||||
|
@ -219,6 +219,7 @@ def main(args):
|
|||||||
device=device,
|
device=device,
|
||||||
hf_access_token=args.hf_access_token,
|
hf_access_token=args.hf_access_token,
|
||||||
sd_disable_nsfw=args.sd_disable_nsfw,
|
sd_disable_nsfw=args.sd_disable_nsfw,
|
||||||
|
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||||
callbacks=[diffuser_callback],
|
callbacks=[diffuser_callback],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user