fix cv2 params
This commit is contained in:
parent
521a1e2858
commit
f4fcece180
@ -11,6 +11,9 @@
|
|||||||
<a href="https://colab.research.google.com/drive/1e3ZkAJxvkK3uzaTGu91N9TvI_Mahs0Wb?usp=sharing">
|
<a href="https://colab.research.google.com/drive/1e3ZkAJxvkK3uzaTGu91N9TvI_Mahs0Wb?usp=sharing">
|
||||||
<img alt="Open in Colab" src="https://colab.research.google.com/assets/colab-badge.svg" />
|
<img alt="Open in Colab" src="https://colab.research.google.com/assets/colab-badge.svg" />
|
||||||
</a>
|
</a>
|
||||||
|
<a href="https://www.python.org/downloads/">
|
||||||
|
<img alt="python version" src="https://img.shields.io/badge/python-3.7+-blue.svg" />
|
||||||
|
</a>
|
||||||
<a href="https://hub.docker.com/r/cwq1913/lama-cleaner">
|
<a href="https://hub.docker.com/r/cwq1913/lama-cleaner">
|
||||||
<img alt="version" src="https://img.shields.io/docker/pulls/cwq1913/lama-cleaner" />
|
<img alt="version" src="https://img.shields.io/docker/pulls/cwq1913/lama-cleaner" />
|
||||||
</a>
|
</a>
|
||||||
|
@ -46,6 +46,9 @@ export default async function inpaint(
|
|||||||
fd.append('sdSampler', settings.sdSampler.toString())
|
fd.append('sdSampler', settings.sdSampler.toString())
|
||||||
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
||||||
|
|
||||||
|
fd.append('cv2Radius', settings.cv2Radius.toString())
|
||||||
|
fd.append('cv2Flag', settings.cv2Flag.toString())
|
||||||
|
|
||||||
if (sizeLimit === undefined) {
|
if (sizeLimit === undefined) {
|
||||||
fd.append('sizeLimit', '1080')
|
fd.append('sizeLimit', '1080')
|
||||||
} else {
|
} else {
|
||||||
|
@ -2,6 +2,11 @@ import cv2
|
|||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
|
flag_map = {
|
||||||
|
"INPAINT_NS": cv2.INPAINT_NS,
|
||||||
|
"INPAINT_TELEA": cv2.INPAINT_TELEA
|
||||||
|
}
|
||||||
|
|
||||||
class OpenCV2(InpaintModel):
|
class OpenCV2(InpaintModel):
|
||||||
pad_mod = 1
|
pad_mod = 1
|
||||||
|
|
||||||
@ -15,5 +20,5 @@ class OpenCV2(InpaintModel):
|
|||||||
mask: [H, W, 1]
|
mask: [H, W, 1]
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
|
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
|
||||||
return cur_res
|
return cur_res
|
||||||
|
@ -44,3 +44,7 @@ class Config(BaseModel):
|
|||||||
sd_sampler: str = SDSampler.ddim
|
sd_sampler: str = SDSampler.ddim
|
||||||
# -1 mean random seed
|
# -1 mean random seed
|
||||||
sd_seed: int = 42
|
sd_seed: int = 42
|
||||||
|
|
||||||
|
# cv2
|
||||||
|
cv2_flag: str = 'INPAINT_NS'
|
||||||
|
cv2_radius: int = 4
|
||||||
|
@ -124,6 +124,8 @@ def process():
|
|||||||
sd_guidance_scale=form["sdGuidanceScale"],
|
sd_guidance_scale=form["sdGuidanceScale"],
|
||||||
sd_sampler=form["sdSampler"],
|
sd_sampler=form["sdSampler"],
|
||||||
sd_seed=form["sdSeed"],
|
sd_seed=form["sdSeed"],
|
||||||
|
cv2_flag=form["cv2Flag"],
|
||||||
|
cv2_radius=form['cv2Radius']
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.sd_seed == -1:
|
if config.sd_seed == -1:
|
||||||
|
@ -228,3 +228,22 @@ def test_sd_run_local(strategy, sampler, disable_nsfw):
|
|||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("cv2_flag", ['INPAINT_NS', 'INPAINT_TELEA'])
|
||||||
|
@pytest.mark.parametrize("cv2_radius", [3, 15])
|
||||||
|
def test_cv2(strategy, cv2_flag, cv2_radius):
|
||||||
|
model = ModelManager(
|
||||||
|
name="cv2",
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
cfg = get_config(strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius)
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"sd_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user