2023-03-29 16:05:34 +02:00
import os
2024-01-02 07:34:36 +01:00
from loguru import logger
2023-12-28 03:48:52 +01:00
from lama_cleaner . tests . utils import check_device , get_config , assert_equal
2023-03-29 16:05:34 +02:00
os . environ [ " PYTORCH_ENABLE_MPS_FALLBACK " ] = " 1 "
2022-11-15 14:09:51 +01:00
from pathlib import Path
import pytest
import torch
from lama_cleaner . model_manager import ModelManager
2023-12-19 06:16:30 +01:00
from lama_cleaner . schema import HDStrategy , SDSampler , FREEUConfig
2022-11-15 14:09:51 +01:00
# Absolute directory of this test module; the input image and mask used by the
# tests below are looked up relative to it.
current_dir = Path(__file__).parent.resolve()

# "result" sub-directory, created at import time if it does not exist yet.
# Presumably where generated images end up (not used directly in this chunk) —
# TODO confirm against tests.utils.assert_equal.
save_dir = current_dir.joinpath("result")
save_dir.mkdir(parents=True, exist_ok=True)
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps"])
def test_runway_sd_1_5_all_samplers(device):
    """Run the SD 1.5 inpainting model once with every ``SDSampler`` value.

    Each sampler is checked with ``assert_equal`` under the name
    ``runway_sd_device_<device>_<sampler>.png`` using the fox/bench prompt.
    On MPS, samplers known to fail in diffusers are skipped with a warning
    that names the skipped sampler.

    Args:
        device: Torch device name. ``check_device`` returns the number of
            diffusion steps to use (presumably it also skips the test when the
            device is unavailable — TODO confirm in tests.utils).
    """
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )

    # diffusers 0.25.0 still fails on MPS for these samplers (float64 tensors
    # are not supported there); remove once a fixed diffusers release is used.
    # Compared via ``.value`` because the loop below iterates raw enum values,
    # so this does not depend on SDSampler being a str-mixin enum.
    mps_unsupported = (
        SDSampler.dpm2_karras.value,
        SDSampler.dpm2_a_karras.value,
        SDSampler.lms_karras.value,
    )

    all_samplers = [member.value for member in SDSampler.__members__.values()]
    print(all_samplers)
    for sampler in all_samplers:
        print(f"Testing sampler {sampler}")
        if device == "mps" and sampler in mps_unsupported:
            # Name the sampler actually being skipped (previously always
            # reported "dpm2_karras").
            logger.warning(
                f"skip {sampler} on mps, diffusers does not support it on mps. "
                "TypeError: Cannot convert a MPS Tensor to float64 dtype as the "
                "MPS framework doesn't support float64. Please use float32 instead."
            )
            continue
        cfg = get_config(
            strategy=HDStrategy.ORIGINAL,
            prompt="a fox sitting on a bench",
            sd_steps=sd_steps,
            sd_sampler=sampler,
        )
        name = f"device_{device}_{sampler}"
        assert_equal(
            model,
            cfg,
            f"runway_sd_{name}.png",
            img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
            mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        )
2022-11-15 14:09:51 +01:00
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("sampler", [SDSampler.lcm])
def test_runway_sd_lcm_lora(device, sampler):
    """Inpaint with the SD 1.5 inpainting model with the LCM LoRA enabled.

    The step count returned by ``check_device`` is ignored here; a fixed 5
    steps and guidance scale 2 are used instead (presumably because LCM is
    meant to run with very few steps). The result is checked under
    ``runway_sd_1_5_lcm_lora_device_<device>.png``.
    """
    check_device(device)

    lcm_steps = 5
    model_kwargs = dict(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )
    model = ModelManager(**model_kwargs)

    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt="face of a fox, sitting on a bench",
        sd_steps=lcm_steps,
        sd_guidance_scale=2,
        sd_lcm_lora=True,
    )
    cfg.sd_sampler = sampler

    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
    assert_equal(
        model,
        cfg,
        f"runway_sd_1_5_lcm_lora_device_{device}.png",
        img_p=image_path,
        mask_p=mask_path,
    )
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_sd_freeu(device, sampler):
    """Inpaint with the SD 1.5 inpainting model with FreeU switched on.

    FreeU uses a default-constructed ``FREEUConfig()`` and guidance scale 7.5;
    the result is checked under ``runway_sd_1_5_freeu_device_<device>.png``.
    """
    steps = check_device(device)
    torch_device = torch.device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch_device,
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )

    freeu_config = FREEUConfig()
    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt="face of a fox, sitting on a bench",
        sd_steps=steps,
        sd_guidance_scale=7.5,
        sd_freeu=True,
        sd_freeu_config=freeu_config,
    )
    cfg.sd_sampler = sampler

    output_name = f"runway_sd_1_5_freeu_device_{device}.png"
    assert_equal(
        model,
        cfg,
        output_name,
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
2023-01-05 15:07:39 +01:00
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_sd_sd_strength(device, strategy, sampler):
    """Inpaint with the SD 1.5 inpainting model at ``sd_strength=0.8``.

    The result is checked under ``runway_sd_strength_0.8_device_<device>.png``.
    """
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-inpainting",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )

    cfg = get_config(
        strategy=strategy,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
        sd_strength=0.8,
    )
    cfg.sd_sampler = sampler

    # Input image and mask, passed to assert_equal by keyword.
    inputs = {
        "img_p": current_dir / "overture-creations-5sI6fQgYIuo.png",
        "mask_p": current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    }
    assert_equal(
        model,
        cfg,
        f"runway_sd_strength_0.8_device_{device}.png",
        **inputs,
    )
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_runway_norm_sd_model(device, strategy, sampler):
    """Inpaint using the plain SD 1.5 checkpoint.

    Loads ``runwayml/stable-diffusion-v1-5`` instead of the dedicated
    ``stable-diffusion-inpainting`` weights used by the other tests.
    """
    sd_steps = check_device(device)
    model = ModelManager(
        name="runwayml/stable-diffusion-v1-5",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )

    cfg = get_config(
        strategy=strategy,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
    )
    cfg.sd_sampler = sampler

    # NOTE(review): ``device`` appears twice in this file name; kept unchanged
    # so the generated result names stay stable.
    result_name = f"runway_{device}_norm_sd_model_device_{device}.png"
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
    assert_equal(model, cfg, result_name, img_p=image_path, mask_p=mask_path)
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
def test_runway_sd_1_5_cpu_offload(device, strategy, sampler):
    """Inpaint with the SD 1.5 inpainting model and ``cpu_offload=True``.

    Only parametrized for CUDA. The result name embeds the capitalized
    strategy, the device and the sampler, and ends in ``_cpu_offload.png``.
    """
    sd_steps = check_device(device)
    model_kwargs = {
        "name": "runwayml/stable-diffusion-inpainting",
        "device": torch.device(device),
        "disable_nsfw": True,
        "sd_cpu_textencoder": False,
        "cpu_offload": True,
    }
    model = ModelManager(**model_kwargs)

    cfg = get_config(
        strategy=strategy,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
    )
    cfg.sd_sampler = sampler

    run_tag = f"device_{device}_{sampler}"
    assert_equal(
        model,
        cfg,
        f"runway_sd_{strategy.capitalize()}_{run_tag}_cpu_offload.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )
2023-01-07 01:52:11 +01:00
2023-03-29 16:05:34 +02:00
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize(
    "name",
    [
        "sd-v1-5-inpainting.ckpt",
        "sd-v1-5-inpainting.safetensors",
        "v1-5-pruned-emaonly.safetensors",
    ],
)
def test_local_file_path(device, sampler, name):
    """Build the model from a checkpoint file name and run one inpainting pass.

    ``name`` (a ``.ckpt`` / ``.safetensors`` file name) is passed straight to
    ``ModelManager``, presumably to be resolved to a local file — TODO confirm.
    The result is checked under
    ``sd_local_model_device_<device>_<sampler>_<name>.png``.
    """
    sd_steps = check_device(device)
    model = ModelManager(
        name=name,
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
        cpu_offload=False,
    )

    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt="a fox sitting on a bench",
        sd_steps=sd_steps,
    )
    cfg.sd_sampler = sampler

    # Separate local so the ``name`` parameter is not rebound.
    result_tag = f"device_{device}_{sampler}_{name}"
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
    assert_equal(
        model,
        cfg,
        f"sd_local_model_{result_tag}.png",
        img_p=image_path,
        mask_p=mask_path,
    )