diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index 9f2f175..3faf5eb 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -8,33 +8,28 @@ import torch.nn.functional as F from lama_cleaner.helper import get_cache_path_by_url, load_jit_model from lama_cleaner.schema import Config -from skimage.color import rgb2gray -from skimage.feature import canny import numpy as np from lama_cleaner.model.base import InpaintModel ZITS_INPAINT_MODEL_URL = os.environ.get( "ZITS_INPAINT_MODEL_URL", - # "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt", - "/Users/qing/code/github/ZITS_inpainting/zits-inpaint.pt" + "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt", ) ZITS_EDGE_LINE_MODEL_URL = os.environ.get( "ZITS_EDGE_LINE_MODEL_URL", - # "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt", - "/Users/qing/code/github/ZITS_inpainting/zits-edge-line.pt" + "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt", ) ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get( "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL", - "https://github.com/Sanster/models/releases/download/add_ldm/zits-structure-upsample.pt", + "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt", ) ZITS_WIRE_FRAME_MODEL_URL = os.environ.get( "ZITS_WIRE_FRAME_MODEL_URL", - # "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt", - "/Users/qing/code/github/ZITS_inpainting/zits-wireframe.pt" + "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt", ) @@ -158,8 +153,19 @@ def load_image(img, mask, device, sigma256=3.0): mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA) mask_512[mask_512 > 0] = 255 - gray_256 = rgb2gray(img_256) - edge_256 = canny(gray_256, sigma=sigma256, mask=None).astype(float) + # original skimage implemention + # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny + # low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max. + # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max. + gray_256 = skimage.color.rgb2gray(img_256) + edge_256 = skimage.feature.canny(gray_256, sigma=sigma256, mask=None).astype(float) + # cv2.imwrite("skimage_gray.jpg", (_gray_256*255).astype(np.uint8)) + # cv2.imwrite("skimage_edge.jpg", (_edge_256*255).astype(np.uint8)) + + # gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY) + # gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(3,3), sigmaX=sigma256, sigmaY=sigma256) + # edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2)) + # cv2.imwrite("edge.jpg", edge_256) # line img_512 = resize(img, 512, 512) @@ -307,7 +313,7 @@ class ZITS(InpaintModel): items["rel_pos"], items["direct"]) inpainted_image = inpainted_image * 255.0 - inpainted_image = inpainted_image.permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) + inpainted_image = inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) inpainted_image = inpainted_image[:, :, ::-1] # cv2.imwrite("inpainted.jpg", inpainted_image) diff --git a/requirements.txt b/requirements.txt index ea6ac6d..87700e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ loguru pytest yacs markupsafe==2.0.1 -scikit-image==0.17.2 \ No newline at end of file +scikit-image \ No newline at end of file