fix cv2 params
This commit is contained in:
parent
521a1e2858
commit
f4fcece180
@ -11,6 +11,9 @@
|
||||
<a href="https://colab.research.google.com/drive/1e3ZkAJxvkK3uzaTGu91N9TvI_Mahs0Wb?usp=sharing">
|
||||
<img alt="Open in Colab" src="https://colab.research.google.com/assets/colab-badge.svg" />
|
||||
</a>
|
||||
<a href="https://www.python.org/downloads/">
|
||||
<img alt="python version" src="https://img.shields.io/badge/python-3.7+-blue.svg" />
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/cwq1913/lama-cleaner">
|
||||
<img alt="version" src="https://img.shields.io/docker/pulls/cwq1913/lama-cleaner" />
|
||||
</a>
|
||||
|
@ -46,6 +46,9 @@ export default async function inpaint(
|
||||
fd.append('sdSampler', settings.sdSampler.toString())
|
||||
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
||||
|
||||
fd.append('cv2Radius', settings.cv2Radius.toString())
|
||||
fd.append('cv2Flag', settings.cv2Flag.toString())
|
||||
|
||||
if (sizeLimit === undefined) {
|
||||
fd.append('sizeLimit', '1080')
|
||||
} else {
|
||||
|
@ -2,6 +2,11 @@ import cv2
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
flag_map = {
|
||||
"INPAINT_NS": cv2.INPAINT_NS,
|
||||
"INPAINT_TELEA": cv2.INPAINT_TELEA
|
||||
}
|
||||
|
||||
class OpenCV2(InpaintModel):
|
||||
pad_mod = 1
|
||||
|
||||
@ -15,5 +20,5 @@ class OpenCV2(InpaintModel):
|
||||
mask: [H, W, 1]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
|
||||
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
|
||||
return cur_res
|
||||
|
@ -44,3 +44,7 @@ class Config(BaseModel):
|
||||
sd_sampler: str = SDSampler.ddim
|
||||
# -1 mean random seed
|
||||
sd_seed: int = 42
|
||||
|
||||
# cv2
|
||||
cv2_flag: str = 'INPAINT_NS'
|
||||
cv2_radius: int = 4
|
||||
|
@ -124,6 +124,8 @@ def process():
|
||||
sd_guidance_scale=form["sdGuidanceScale"],
|
||||
sd_sampler=form["sdSampler"],
|
||||
sd_seed=form["sdSeed"],
|
||||
cv2_flag=form["cv2Flag"],
|
||||
cv2_radius=form['cv2Radius']
|
||||
)
|
||||
|
||||
if config.sd_seed == -1:
|
||||
|
@ -228,3 +228,22 @@ def test_sd_run_local(strategy, sampler, disable_nsfw):
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
@pytest.mark.parametrize("cv2_flag", ['INPAINT_NS', 'INPAINT_TELEA'])
|
||||
@pytest.mark.parametrize("cv2_radius", [3, 15])
|
||||
def test_cv2(strategy, cv2_flag, cv2_radius):
|
||||
model = ModelManager(
|
||||
name="cv2",
|
||||
device=device,
|
||||
)
|
||||
cfg = get_config(strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user