better handle scan single file diffusion
This commit is contained in:
parent
0a262fa811
commit
76823355fe
@ -68,7 +68,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType:
|
|||||||
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
||||||
model_type = ModelType.DIFFUSERS_SD
|
model_type = ModelType.DIFFUSERS_SD
|
||||||
else:
|
else:
|
||||||
raise e
|
logger.info(f"Ignore non sd or sdxl file: {model_abs_path}")
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
|
|||||||
if "but got torch.Size([320, 4, 3, 3])" in str(e):
|
if "but got torch.Size([320, 4, 3, 3])" in str(e):
|
||||||
model_type = ModelType.DIFFUSERS_SDXL
|
model_type = ModelType.DIFFUSERS_SDXL
|
||||||
else:
|
else:
|
||||||
raise e
|
logger.info(f"Ignore non sd or sdxl file: {model_abs_path}")
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class BrushNetWrapper(DiffusionInpaintModel):
|
|||||||
self.model_id_or_path,
|
self.model_id_or_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
load_safety_checker=not disable_nsfw_checker,
|
load_safety_checker=not disable_nsfw_checker,
|
||||||
config_files=get_config_files(),
|
original_config_file=get_config_files()['v1'],
|
||||||
brushnet=brushnet,
|
brushnet=brushnet,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
@ -42,7 +42,7 @@ def test_runway_brushnet(device, sampler):
|
|||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
cfg,
|
cfg,
|
||||||
f"brushnet_runway_1_5_freeu_device_{device}.png",
|
f"brushnet_random_mask_{device}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
)
|
)
|
||||||
@ -74,16 +74,13 @@ def test_brushnet_local_file_path(device, sampler, name):
|
|||||||
brushnet_method=SD_BRUSHNET_CHOICES[1]
|
brushnet_method=SD_BRUSHNET_CHOICES[1]
|
||||||
)
|
)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
name = f"device_{device}_{sampler}_{name}"
|
|
||||||
|
|
||||||
is_sdxl = "sd_xl" in name
|
|
||||||
|
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
cfg,
|
cfg,
|
||||||
f"brushnet_sd_local_model_{name}.png",
|
f"brushnet_segmentation_mask_{device}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1.5 if is_sdxl else 1,
|
fx=1,
|
||||||
fy=1.5 if is_sdxl else 1,
|
fy=1,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user