Merge branch '1108'
This commit is contained in:
commit
0cfb06ba1a
@ -68,7 +68,7 @@ const SidePanel = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<NumberInputSetting
|
{/* <NumberInputSetting
|
||||||
title="Strength"
|
title="Strength"
|
||||||
width={INPUT_WIDTH}
|
width={INPUT_WIDTH}
|
||||||
allowFloat
|
allowFloat
|
||||||
@ -81,7 +81,7 @@ const SidePanel = () => {
|
|||||||
return { ...old, sdStrength: val }
|
return { ...old, sdStrength: val }
|
||||||
})
|
})
|
||||||
}}
|
}}
|
||||||
/>
|
/> */}
|
||||||
|
|
||||||
<NumberInputSetting
|
<NumberInputSetting
|
||||||
title="Guidance Scale"
|
title="Guidance Scale"
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import multiprocessing
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@ -9,9 +8,9 @@ import numpy as np
|
|||||||
import nvidia_smi
|
import nvidia_smi
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from lama_cleaner.lama import LaMa
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||||
@ -21,8 +20,6 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from lama_cleaner.helper import norm_img
|
|
||||||
|
|
||||||
NUM_THREADS = str(4)
|
NUM_THREADS = str(4)
|
||||||
|
|
||||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||||
@ -37,20 +34,23 @@ if os.environ.get("CACHE_DIR"):
|
|||||||
def run_model(model, size):
|
def run_model(model, size):
|
||||||
# RGB
|
# RGB
|
||||||
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
||||||
image = norm_img(image)
|
|
||||||
|
|
||||||
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
||||||
mask = norm_img(mask)
|
|
||||||
model(image, mask)
|
config = Config(
|
||||||
|
ldm_steps=2,
|
||||||
|
hd_strategy=HDStrategy.ORIGINAL,
|
||||||
|
hd_strategy_crop_margin=128,
|
||||||
|
hd_strategy_crop_trigger_size=128,
|
||||||
|
hd_strategy_resize_limit=128,
|
||||||
|
prompt="a fox is sitting on a bench",
|
||||||
|
sd_steps=5,
|
||||||
|
sd_sampler=SDSampler.ddim
|
||||||
|
)
|
||||||
|
model(image, mask, config)
|
||||||
|
|
||||||
|
|
||||||
def benchmark(model, times: int, empty_cache: bool):
|
def benchmark(model, times: int, empty_cache: bool):
|
||||||
sizes = [
|
sizes = [(512, 512)]
|
||||||
(512, 512),
|
|
||||||
(640, 640),
|
|
||||||
(1080, 800),
|
|
||||||
(2000, 2000)
|
|
||||||
]
|
|
||||||
|
|
||||||
nvidia_smi.nvmlInit()
|
nvidia_smi.nvmlInit()
|
||||||
device_id = 0
|
device_id = 0
|
||||||
@ -71,8 +71,6 @@ def benchmark(model, times: int, empty_cache: bool):
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
run_model(model, size)
|
run_model(model, size)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
if empty_cache:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# cpu_metrics.append(process.cpu_percent())
|
# cpu_metrics.append(process.cpu_percent())
|
||||||
time_metrics.append((time.time() - start) * 1000)
|
time_metrics.append((time.time() - start) * 1000)
|
||||||
@ -90,8 +88,9 @@ def benchmark(model, times: int, empty_cache: bool):
|
|||||||
|
|
||||||
def get_args_parser():
|
def get_args_parser():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--name")
|
||||||
parser.add_argument("--device", default="cuda", type=str)
|
parser.add_argument("--device", default="cuda", type=str)
|
||||||
parser.add_argument("--times", default=20, type=int)
|
parser.add_argument("--times", default=10, type=int)
|
||||||
parser.add_argument("--empty-cache", action="store_true")
|
parser.add_argument("--empty-cache", action="store_true")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -99,5 +98,12 @@ def get_args_parser():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args_parser()
|
args = get_args_parser()
|
||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
model = LaMa(device)
|
model = ModelManager(
|
||||||
|
name=args.name,
|
||||||
|
device=device,
|
||||||
|
sd_run_local=True,
|
||||||
|
sd_disable_nsfw=True,
|
||||||
|
sd_cpu_textencoder=True,
|
||||||
|
hf_access_token="123"
|
||||||
|
)
|
||||||
benchmark(model, args.times, args.empty_cache)
|
benchmark(model, args.times, args.empty_cache)
|
||||||
|
@ -37,13 +37,14 @@ from lama_cleaner.schema import Config, SDSampler
|
|||||||
# return mask
|
# return mask
|
||||||
|
|
||||||
class CPUTextEncoderWrapper:
|
class CPUTextEncoderWrapper:
|
||||||
def __init__(self, text_encoder):
|
def __init__(self, text_encoder, torch_dtype):
|
||||||
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
|
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
|
||||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
input_device = x.device
|
input_device = x.device
|
||||||
return [self.text_encoder(x.to(self.text_encoder.device))[0].to(input_device)]
|
return [self.text_encoder(x.to(self.text_encoder.device))[0].to(input_device).to(self.torch_dtype)]
|
||||||
|
|
||||||
|
|
||||||
class SD(InpaintModel):
|
class SD(InpaintModel):
|
||||||
@ -61,11 +62,11 @@ class SD(InpaintModel):
|
|||||||
))
|
))
|
||||||
|
|
||||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||||
|
torch_dtype = torch.float16 if use_gpu else torch.float32
|
||||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
self.model_id_or_path,
|
self.model_id_or_path,
|
||||||
revision="fp16" if use_gpu else "main",
|
revision="fp16" if use_gpu else "main",
|
||||||
torch_dtype=torch.float16 if use_gpu else torch.float32,
|
torch_dtype=torch_dtype,
|
||||||
use_auth_token=kwargs["hf_access_token"],
|
use_auth_token=kwargs["hf_access_token"],
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
)
|
)
|
||||||
@ -75,11 +76,10 @@ class SD(InpaintModel):
|
|||||||
|
|
||||||
if kwargs['sd_cpu_textencoder']:
|
if kwargs['sd_cpu_textencoder']:
|
||||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder)
|
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
@torch.cuda.amp.autocast()
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
|
@ -14,7 +14,7 @@ save_dir.mkdir(exist_ok=True, parents=True)
|
|||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
|
||||||
def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
def get_data(fx: float = 1, fy: float = 1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||||
img = cv2.imread(str(img_p))
|
img = cv2.imread(str(img_p))
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||||
@ -36,7 +36,10 @@ def get_config(strategy, **kwargs):
|
|||||||
return Config(**data)
|
return Config(**data)
|
||||||
|
|
||||||
|
|
||||||
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
def assert_equal(model, config, gt_name,
|
||||||
|
fx: float = 1, fy: float = 1,
|
||||||
|
img_p=current_dir / "image.png",
|
||||||
|
mask_p=current_dir / "mask.png"):
|
||||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||||
print(f"Input image shape: {img.shape}")
|
print(f"Input image shape: {img.shape}")
|
||||||
res = model(img, mask, config)
|
res = model(img, mask, config)
|
||||||
@ -157,105 +160,40 @@ def test_fcf(strategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sd_device", ['cpu', 'cuda'])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
|
||||||
def test_sd(strategy, sampler):
|
|
||||||
def callback(i, t, latents):
|
|
||||||
print(f"sd_step_{i}")
|
|
||||||
|
|
||||||
sd_steps = 50
|
|
||||||
model = ModelManager(name="sd1.4",
|
|
||||||
device=device,
|
|
||||||
hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
|
||||||
sd_run_local=False,
|
|
||||||
sd_disable_nsfw=False,
|
|
||||||
sd_cpu_textencoder=False,
|
|
||||||
callback=callback)
|
|
||||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
|
||||||
cfg.sd_sampler = sampler
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"sd_{strategy.capitalize()}_{sampler}_result.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
|
|
||||||
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
|
||||||
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
||||||
def test_sd_run_local(strategy, sampler, disable_nsfw, cpu_textencoder):
|
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||||
|
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
||||||
def callback(i, t, latents):
|
def callback(i, t, latents):
|
||||||
print(f"sd_step_{i}")
|
print(f"sd_step_{i}")
|
||||||
|
|
||||||
sd_steps = 50
|
if sd_device == 'cuda' and not torch.cuda.is_available():
|
||||||
model = ModelManager(
|
return
|
||||||
name="sd1.4",
|
|
||||||
device=device,
|
sd_steps = 1
|
||||||
# hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None),
|
model = ModelManager(name="sd1.5",
|
||||||
hf_access_token=None,
|
device=sd_device,
|
||||||
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=disable_nsfw,
|
sd_disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=cpu_textencoder,
|
sd_cpu_textencoder=cpu_textencoder,
|
||||||
)
|
|
||||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
|
||||||
cfg.sd_sampler = sampler
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"sd_{strategy.capitalize()}_{sampler}_local_disablensfw_{disable_nsfw}_cputextencoder_{cpu_textencoder}_result.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
|
|
||||||
def test_runway_sd_1_5(strategy, sampler):
|
|
||||||
def callback(i, t, latents):
|
|
||||||
print(f"sd_step_{i}")
|
|
||||||
|
|
||||||
sd_steps = 20
|
|
||||||
model = ModelManager(name="sd1.5",
|
|
||||||
device=device,
|
|
||||||
hf_access_token=None,
|
|
||||||
sd_run_local=True,
|
|
||||||
sd_disable_nsfw=True,
|
|
||||||
sd_cpu_textencoder=True,
|
|
||||||
callback=callback)
|
callback=callback)
|
||||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
|
name = f"{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||||
|
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
cfg,
|
cfg,
|
||||||
f"runway_sd_{strategy.capitalize()}_{sampler}_result.png",
|
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1.3
|
fx=1.3
|
||||||
)
|
)
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"runway_sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
|
||||||
fy=1.3
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||||
|
@ -2,7 +2,7 @@ torch>=1.9.0
|
|||||||
opencv-python
|
opencv-python
|
||||||
flask_cors
|
flask_cors
|
||||||
flask==1.1.4
|
flask==1.1.4
|
||||||
flaskwebgui
|
flaskwebgui==0.3.5
|
||||||
tqdm
|
tqdm
|
||||||
pydantic
|
pydantic
|
||||||
loguru
|
loguru
|
||||||
|
Loading…
Reference in New Issue
Block a user