better handle scan single file diffusion

This commit is contained in:
Qing 2024-04-12 18:52:33 +08:00
parent 0a262fa811
commit 76823355fe
3 changed files with 7 additions and 10 deletions

View File

@ -68,7 +68,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType:
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e): if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
model_type = ModelType.DIFFUSERS_SD model_type = ModelType.DIFFUSERS_SD
else: else:
raise e logger.info(f"Ignore non sd or sdxl file: {model_abs_path}")
return model_type return model_type
@ -96,7 +96,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
if "but got torch.Size([320, 4, 3, 3])" in str(e): if "but got torch.Size([320, 4, 3, 3])" in str(e):
model_type = ModelType.DIFFUSERS_SDXL model_type = ModelType.DIFFUSERS_SDXL
else: else:
raise e logger.info(f"Ignore non sd or sdxl file: {model_abs_path}")
return model_type return model_type

View File

@ -64,7 +64,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
self.model_id_or_path, self.model_id_or_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker, load_safety_checker=not disable_nsfw_checker,
config_files=get_config_files(), original_config_file=get_config_files()['v1'],
brushnet=brushnet, brushnet=brushnet,
**model_kwargs, **model_kwargs,
) )

View File

@ -42,7 +42,7 @@ def test_runway_brushnet(device, sampler):
assert_equal( assert_equal(
model, model,
cfg, cfg,
f"brushnet_runway_1_5_freeu_device_{device}.png", f"brushnet_random_mask_{device}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
) )
@ -74,16 +74,13 @@ def test_brushnet_local_file_path(device, sampler, name):
brushnet_method=SD_BRUSHNET_CHOICES[1] brushnet_method=SD_BRUSHNET_CHOICES[1]
) )
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
name = f"device_{device}_{sampler}_{name}"
is_sdxl = "sd_xl" in name
assert_equal( assert_equal(
model, model,
cfg, cfg,
f"brushnet_sd_local_model_{name}.png", f"brushnet_segmentation_mask_{device}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1.5 if is_sdxl else 1, fx=1,
fy=1.5 if is_sdxl else 1, fy=1,
) )