2023-03-29 16:05:34 +02:00
import os
2024-01-02 07:34:36 +01:00
from loguru import logger
2024-01-05 08:19:23 +01:00
from iopaint . tests . utils import check_device , get_config , assert_equal
2023-12-28 03:48:52 +01:00
2023-03-29 16:05:34 +02:00
# Set before ``torch`` is imported further down this module.
# NOTE(review): presumably this lets ops unsupported on MPS fall back to CPU;
# confirm against the PyTorch MPS documentation.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
2022-11-15 14:09:51 +01:00
from pathlib import Path
import pytest
import torch
2024-01-05 08:19:23 +01:00
from iopaint . model_manager import ModelManager
2024-04-25 16:21:33 +02:00
from iopaint . schema import HDStrategy , SDSampler
2022-11-15 14:09:51 +01:00
# Absolute directory containing this module; the tests below read their input
# image and mask from here.
current_dir = Path(__file__).parent.absolute().resolve()

# "result" directory next to this module, created at import time if missing.
# NOTE(review): presumably where generated images are written -- it is not
# referenced directly in this module; confirm in iopaint.tests.utils.
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps"])
def test_runway_sd_1_5_all_samplers(device):
    """Run the runwayml inpainting model once for every ``SDSampler`` value.

    One ``ModelManager`` is built for ``device`` and reused across samplers.
    Each sampler gets its own config and its own output file name.

    Args:
        device: Device string supplied by ``parametrize`` ("cuda" or "mps").
    """
    # check_device returns the step count used for this device.
    # NOTE(review): presumably it also skips the test when the device is
    # unavailable -- confirm in iopaint.tests.utils.
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )

    # Iterate over the enum *values* so the file name below is built from the
    # plain value rather than from the enum member's formatting.
    all_samplers = [member.value for member in SDSampler.__members__.values()]
    print(all_samplers)

    # diffusers 0.25.0 still has a bug with these samplers on mps; wait for
    # the fix on the main branch to be released.
    mps_unsupported = [
        SDSampler.dpm2_karras,
        SDSampler.dpm2_a_karras,
        SDSampler.lms_karras,
    ]

    for sampler in all_samplers:
        print(f"Testing sampler {sampler}")
        if device == "mps" and sampler in mps_unsupported:
            # Name the sampler actually being skipped; the old message always
            # said dpm2_karras even for the other two.
            logger.warning(
                f"skip {sampler} on mps, diffusers does not support it on mps. "
                "TypeError: Cannot convert a MPS Tensor to float64 dtype as the "
                "MPS framework doesn't support float64. Please use float32 instead."
            )
            continue

        cfg = get_config(
            strategy=HDStrategy.ORIGINAL,
            prompt="a fox sitting on a bench",
            sd_steps=sd_steps,
            sd_sampler=sampler,
        )
        name = f"device_{device}_{sampler}"

        assert_equal(
            model,
            cfg,
            f"runway_sd_{name}.png",
            img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
            mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        )
2022-11-15 14:09:51 +01:00
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("sampler", [SDSampler.lcm])
def test_runway_sd_lcm_lora(device, sampler):
    """Run the runwayml inpainting model with ``sd_lcm_lora`` enabled.

    Uses a fixed 5 steps and guidance scale 2 instead of the step count
    returned by ``check_device``.

    Args:
        device: Device string supplied by ``parametrize``.
        sampler: Sampler assigned to the config (``SDSampler.lcm``).
    """
    # Called for its side effects only; the returned step count is ignored.
    # NOTE(review): presumably skips when the device is unavailable -- confirm.
    check_device(device)

    sd_steps = 5
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )
    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
        sd_guidance_scale=2,
        sd_lcm_lora=True,
    )
    cfg.sd_sampler = sampler

    assert_equal(
        model,
        cfg,
        f"runway_sd_1_5_lcm_lora_device_{device}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_sd_sd_strength(device, strategy, sampler):
    """Run the runwayml inpainting model with ``sd_strength=0.8``.

    Args:
        device: Device string supplied by ``parametrize`` ("cuda" or "mps").
        strategy: ``HDStrategy`` passed to the config.
        sampler: Sampler assigned to the config after it is built.
    """
    # check_device returns the step count used for this device.
    # NOTE(review): presumably skips when the device is unavailable -- confirm.
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )
    cfg = get_config(
        strategy=strategy,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
        sd_strength=0.8,
    )
    cfg.sd_sampler = sampler

    assert_equal(
        model,
        cfg,
        f"runway_sd_strength_0.8_device_{device}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
2024-01-10 06:34:11 +01:00
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_sd_cpu_textencoder(device, strategy, sampler):
    """Run the runwayml inpainting model with ``sd_cpu_textencoder=True``.

    Args:
        device: Device string supplied by ``parametrize`` ("cuda" or "cpu").
        strategy: ``HDStrategy`` passed to the config.
        sampler: Sampler passed to the config.
    """
    # check_device returns the step count used for this device.
    # NOTE(review): presumably skips when the device is unavailable -- confirm.
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=True,
    )
    cfg = get_config(
        strategy=strategy,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
        sd_sampler=sampler,
    )

    assert_equal(
        model,
        cfg,
        f"runway_sd_device_{device}_cpu_textencoder.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_norm_sd_model(device, strategy, sampler):
    """Run the ``runwayml/stable-diffusion-v1-5`` model through the same flow.

    Unlike the other tests in this module, this one loads the
    ``stable-diffusion-v1-5`` checkpoint instead of
    ``stable-diffusion-inpainting``.

    Args:
        device: Device string supplied by ``parametrize``.
        strategy: ``HDStrategy`` passed to the config.
        sampler: Sampler assigned to the config after it is built.
    """
    # check_device returns the step count used for this device.
    # NOTE(review): presumably skips when the device is unavailable -- confirm.
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-v1-5",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )
    cfg = get_config(
        strategy=strategy,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
    )
    cfg.sd_sampler = sampler

    assert_equal(
        model,
        cfg,
        # The device appears twice in the name; kept as-is so existing result
        # file names do not change.
        f"runway_{device}_norm_sd_model_device_{device}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
def test_runway_sd_1_5_cpu_offload(device, strategy, sampler):
    """Run the runwayml inpainting model with ``cpu_offload=True``.

    Args:
        device: Device string supplied by ``parametrize`` (only "cuda").
        strategy: ``HDStrategy`` passed to the config and used in the output
            file name.
        sampler: Sampler assigned to the config and used in the output file
            name.
    """
    # check_device returns the step count used for this device.
    # NOTE(review): presumably skips when the device is unavailable -- confirm.
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
        cpu_offload=True,
    )
    cfg = get_config(
        strategy=strategy,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
    )
    cfg.sd_sampler = sampler

    name = f"device_{device}_{sampler}"

    assert_equal(
        model,
        cfg,
        f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
2023-01-07 01:52:11 +01:00
2023-03-29 16:05:34 +02:00
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize(
    "name",
    [
        "sd-v1-5-inpainting.safetensors",
        "v1-5-pruned-emaonly.safetensors",
        "sd_xl_base_1.0.safetensors",
        "sd_xl_base_1.0_inpainting_0.1.safetensors",
    ],
)
def test_local_file_path(device, sampler, name):
    """Run models identified by a ``.safetensors`` file name.

    File names containing ``sd_xl`` are run with ``fx``/``fy`` of 1.5; all
    others use 1.

    Args:
        device: Device string supplied by ``parametrize``.
        sampler: Sampler assigned to the config and used in the output file
            name.
        name: Model file name passed to ``ModelManager`` as ``name``.
    """
    # check_device returns the step count used for this device.
    # NOTE(review): presumably skips when the device is unavailable -- confirm.
    sd_steps = check_device(device)
    model = ModelManager(
        name=name,
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
        cpu_offload=False,
    )
    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
    )
    cfg.sd_sampler = sampler

    # Use a separate local for the output label instead of rebinding the
    # ``name`` parameter; the resulting file name is unchanged.
    case_name = f"device_{device}_{sampler}_{name}"

    # The label embeds the model file name, so this matches exactly when the
    # model file name contains "sd_xl" (same check as before the rename).
    is_sdxl = "sd_xl" in case_name
    scale = 1.5 if is_sdxl else 1

    assert_equal(
        model,
        cfg,
        f"sd_local_model_{case_name}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=scale,
        fy=scale,
    )