add --no-half arg

This commit is contained in:
Qing 2023-01-03 21:30:33 +08:00
parent 6cfc7c30f1
commit 59ee89bd34
5 changed files with 6 additions and 3 deletions

View File

@ -74,7 +74,6 @@ const Toast = React.forwardRef<
}) })
Toast.defaultProps = { Toast.defaultProps = {
desc: '',
state: 'loading', state: 'loading',
} }

View File

@ -15,8 +15,9 @@ class PaintByExample(InpaintModel):
min_size = 512 min_size = 512
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs['no_half']
use_gpu = device == torch.device('cuda') and torch.cuda.is_available() use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = DiffusionPipeline.from_pretrained( self.model = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example", "Fantasy-Studio/Paint-by-Example",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,

View File

@ -30,6 +30,7 @@ class SD(InpaintModel):
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
fp16 = not kwargs['no_half']
model_kwargs = {"local_files_only": kwargs['sd_run_local']} model_kwargs = {"local_files_only": kwargs['sd_run_local']}
if kwargs['sd_disable_nsfw']: if kwargs['sd_disable_nsfw']:
@ -42,7 +43,7 @@ class SD(InpaintModel):
torch_dtype = torch.float16 if use_gpu else torch.float32 torch_dtype = torch.float16 if use_gpu else torch.float32
self.model = StableDiffusionInpaintPipeline.from_pretrained( self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path, self.model_id_or_path,
revision="fp16" if use_gpu else "main", revision="fp16" if use_gpu and fp16 else "main",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
use_auth_token=kwargs["hf_access_token"], use_auth_token=kwargs["hf_access_token"],
**model_kwargs **model_kwargs

View File

@ -12,6 +12,7 @@ def parse_args():
default="lama", default="lama",
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"], choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
) )
parser.add_argument("--no-half", action="store_true", help="SD/PaintByExample model no half precision")
parser.add_argument( parser.add_argument(
"--hf_access_token", "--hf_access_token",
default="", default="",

View File

@ -303,6 +303,7 @@ def main(args):
model = ModelManager( model = ModelManager(
name=args.model, name=args.model,
device=device, device=device,
no_half=args.no_half,
hf_access_token=args.hf_access_token, hf_access_token=args.hf_access_token,
sd_disable_nsfw=args.sd_disable_nsfw, sd_disable_nsfw=args.sd_disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder, sd_cpu_textencoder=args.sd_cpu_textencoder,