add github model url; fix cpu tensor
This commit is contained in:
parent
01c7f3b77d
commit
e11aed0b1e
@ -8,33 +8,28 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
|
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
from skimage.color import rgb2gray
|
|
||||||
from skimage.feature import canny
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
|
|
||||||
ZITS_INPAINT_MODEL_URL = os.environ.get(
|
ZITS_INPAINT_MODEL_URL = os.environ.get(
|
||||||
"ZITS_INPAINT_MODEL_URL",
|
"ZITS_INPAINT_MODEL_URL",
|
||||||
# "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
|
"https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
|
||||||
"/Users/qing/code/github/ZITS_inpainting/zits-inpaint.pt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
|
ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
|
||||||
"ZITS_EDGE_LINE_MODEL_URL",
|
"ZITS_EDGE_LINE_MODEL_URL",
|
||||||
# "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
|
"https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
|
||||||
"/Users/qing/code/github/ZITS_inpainting/zits-edge-line.pt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
|
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
|
||||||
"ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
|
"ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
|
||||||
"https://github.com/Sanster/models/releases/download/add_ldm/zits-structure-upsample.pt",
|
"https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
|
ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
|
||||||
"ZITS_WIRE_FRAME_MODEL_URL",
|
"ZITS_WIRE_FRAME_MODEL_URL",
|
||||||
# "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
|
"https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
|
||||||
"/Users/qing/code/github/ZITS_inpainting/zits-wireframe.pt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -158,8 +153,19 @@ def load_image(img, mask, device, sigma256=3.0):
|
|||||||
mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
|
mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
|
||||||
mask_512[mask_512 > 0] = 255
|
mask_512[mask_512 > 0] = 255
|
||||||
|
|
||||||
gray_256 = rgb2gray(img_256)
|
# original skimage implemention
|
||||||
edge_256 = canny(gray_256, sigma=sigma256, mask=None).astype(float)
|
# https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny
|
||||||
|
# low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max.
|
||||||
|
# high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max.
|
||||||
|
gray_256 = skimage.color.rgb2gray(img_256)
|
||||||
|
edge_256 = skimage.feature.canny(gray_256, sigma=sigma256, mask=None).astype(float)
|
||||||
|
# cv2.imwrite("skimage_gray.jpg", (_gray_256*255).astype(np.uint8))
|
||||||
|
# cv2.imwrite("skimage_edge.jpg", (_edge_256*255).astype(np.uint8))
|
||||||
|
|
||||||
|
# gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
|
||||||
|
# gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(3,3), sigmaX=sigma256, sigmaY=sigma256)
|
||||||
|
# edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2))
|
||||||
|
# cv2.imwrite("edge.jpg", edge_256)
|
||||||
|
|
||||||
# line
|
# line
|
||||||
img_512 = resize(img, 512, 512)
|
img_512 = resize(img, 512, 512)
|
||||||
@ -307,7 +313,7 @@ class ZITS(InpaintModel):
|
|||||||
items["rel_pos"], items["direct"])
|
items["rel_pos"], items["direct"])
|
||||||
|
|
||||||
inpainted_image = inpainted_image * 255.0
|
inpainted_image = inpainted_image * 255.0
|
||||||
inpainted_image = inpainted_image.permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
|
inpainted_image = inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
|
||||||
inpainted_image = inpainted_image[:, :, ::-1]
|
inpainted_image = inpainted_image[:, :, ::-1]
|
||||||
|
|
||||||
# cv2.imwrite("inpainted.jpg", inpainted_image)
|
# cv2.imwrite("inpainted.jpg", inpainted_image)
|
||||||
|
@ -9,4 +9,4 @@ loguru
|
|||||||
pytest
|
pytest
|
||||||
yacs
|
yacs
|
||||||
markupsafe==2.0.1
|
markupsafe==2.0.1
|
||||||
scikit-image==0.17.2
|
scikit-image
|
Loading…
Reference in New Issue
Block a user