diff --git a/README.md b/README.md
index 7ef4ffa..d91f03d 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,9 @@
+
+
+
diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts
index 3f1dc6a..2cf291a 100644
--- a/lama_cleaner/app/src/adapters/inpainting.ts
+++ b/lama_cleaner/app/src/adapters/inpainting.ts
@@ -46,6 +46,9 @@ export default async function inpaint(
fd.append('sdSampler', settings.sdSampler.toString())
fd.append('sdSeed', seed ? seed.toString() : '-1')
+ fd.append('cv2Radius', settings.cv2Radius.toString())
+ fd.append('cv2Flag', settings.cv2Flag.toString())
+
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {
diff --git a/lama_cleaner/model/opencv2.py b/lama_cleaner/model/opencv2.py
index 1a69baa..1802ccd 100644
--- a/lama_cleaner/model/opencv2.py
+++ b/lama_cleaner/model/opencv2.py
@@ -2,6 +2,11 @@ import cv2
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
+flag_map = {
+ "INPAINT_NS": cv2.INPAINT_NS,
+ "INPAINT_TELEA": cv2.INPAINT_TELEA
+}
+
class OpenCV2(InpaintModel):
pad_mod = 1
@@ -15,5 +20,5 @@ class OpenCV2(InpaintModel):
mask: [H, W, 1]
return: BGR IMAGE
"""
- cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
+ cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
return cur_res
diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py
index cfc8aba..265cc4c 100644
--- a/lama_cleaner/schema.py
+++ b/lama_cleaner/schema.py
@@ -44,3 +44,7 @@ class Config(BaseModel):
sd_sampler: str = SDSampler.ddim
# -1 mean random seed
sd_seed: int = 42
+
+ # cv2
+ cv2_flag: str = 'INPAINT_NS'
+ cv2_radius: int = 4
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index c72c782..f83184a 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -124,6 +124,8 @@ def process():
sd_guidance_scale=form["sdGuidanceScale"],
sd_sampler=form["sdSampler"],
sd_seed=form["sdSeed"],
+ cv2_flag=form["cv2Flag"],
+ cv2_radius=form['cv2Radius']
)
if config.sd_seed == -1:
diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py
index f2eac96..383231f 100644
--- a/lama_cleaner/tests/test_model.py
+++ b/lama_cleaner/tests/test_model.py
@@ -228,3 +228,22 @@ def test_sd_run_local(strategy, sampler, disable_nsfw):
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
)
+
+@pytest.mark.parametrize(
+ "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
+)
+@pytest.mark.parametrize("cv2_flag", ['INPAINT_NS', 'INPAINT_TELEA'])
+@pytest.mark.parametrize("cv2_radius", [3, 15])
+def test_cv2(strategy, cv2_flag, cv2_radius):
+ model = ModelManager(
+ name="cv2",
+ device=device,
+ )
+ cfg = get_config(strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius)
+ assert_equal(
+ model,
+ cfg,
+ f"sd_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
+ img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+ mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+ )