IOPaint/lama_cleaner/helper.py
2021-11-15 20:11:46 +01:00

67 lines
1.8 KiB
Python

import os
import sys
from urllib.parse import urlparse
import cv2
import numpy as np
import torch
from torch.hub import download_url_to_file, get_dir
# URL the model file is downloaded from. Setting the LAMA_MODEL_URL
# environment variable overrides the built-in GitHub release location.
LAMA_MODEL_URL = os.getenv(
    "LAMA_MODEL_URL",
    default="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
def download_model(url=LAMA_MODEL_URL):
    """Return the local path of the model file at ``url``, downloading it if needed.

    The file lives in the ``checkpoints`` directory under the torch hub
    directory (``torch.hub.get_dir()``) and is named after the last path
    component of ``url``. When a file with that name already exists it is
    reused and nothing is downloaded.

    Args:
        url: Location of the model file. Defaults to ``LAMA_MODEL_URL``.

    Returns:
        Path of the cached model file.
    """
    model_dir = os.path.join(get_dir(), "checkpoints")
    # Create only the cache directory itself. exist_ok avoids both the
    # separate isdir() check and a failure if another process creates the
    # directory in between.
    os.makedirs(model_dir, exist_ok=True)

    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # No hash prefix is supplied, so the download is not checksum-verified.
        download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
    return cached_file
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    ``x`` is returned unchanged when it is already a multiple of ``mod``.
    """
    quotient, remainder = divmod(x, mod)
    if remainder == 0:
        return x
    # Not aligned: step up to the next multiple.
    return (quotient + 1) * mod
def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
    """Encode ``image_numpy`` as a JPEG and return the encoded bytes."""
    # cv2.imencode returns (success_flag, buffer); only the buffer is used.
    _, encoded = cv2.imencode(".jpg", image_numpy)
    return encoded.tobytes()
def load_img(img_bytes, gray: bool = False, norm: bool = True):
    """Decode an encoded image (e.g. PNG/JPEG bytes) into a numpy array.

    Args:
        img_bytes: Raw encoded image data (any bytes-like buffer).
        gray: If true, decode as a single-channel grayscale image and add a
            trailing channel axis (H, W, 1). Otherwise decode as a
            3-channel image and convert OpenCV's BGR order to RGB.
        norm: If true, move the channel axis first (C, H, W) and scale the
            values to float32 by dividing by 255. If false, the decoded
            (H, W, C) array is returned as is.

    Returns:
        The decoded image as a numpy array, laid out as described above.

    Raises:
        ValueError: If OpenCV cannot decode ``img_bytes``.
    """
    buffer = np.frombuffer(img_bytes, np.uint8)
    read_flag = cv2.IMREAD_GRAYSCALE if gray else cv2.IMREAD_COLOR
    np_img = cv2.imdecode(buffer, read_flag)
    # imdecode signals undecodable data by returning None rather than
    # raising; fail here with a clear message instead of an opaque error
    # from the indexing / colour conversion below.
    if np_img is None:
        raise ValueError("Failed to decode image: unsupported or corrupt data")

    if gray:
        np_img = np_img[:, :, np.newaxis]
    else:
        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)

    if norm:
        # HWC -> CHW, then scale by 1/255.
        np_img = np.transpose(np_img, (2, 0, 1))
        np_img = np_img.astype("float32") / 255
    return np_img
def pad_img_to_modulo(img, mod):
    """Pad a (channels, height, width) array so height and width are multiples of ``mod``.

    Padding is added only at the end of the height and width axes, using
    numpy's "symmetric" mode; the channel axis is left untouched.
    """
    _, height, width = img.shape
    extra_rows = ceil_modulo(height, mod) - height
    extra_cols = ceil_modulo(width, mod) - width
    pad_widths = ((0, 0), (0, extra_rows), (0, extra_cols))
    return np.pad(img, pad_widths, mode="symmetric")