add ldm model
Parent: f09d40cbef · Commit: f9b96cf218

README.md (42 lines changed)
@@ -1,10 +1,10 @@
-# Lama-cleaner: Image inpainting tool powered by [LaMa](https://github.com/saic-mdal/lama)
-
-This project is mainly used for selfhosting LaMa model, some interaction improvements may be added later.
+# Lama-cleaner: Image inpainting tool powered by SOTA AI model

 https://user-images.githubusercontent.com/3998421/153323093-b664bb68-2928-480b-b59b-7c1ee24a4507.mp4

+- [x] Support multiple model architectures
+  1. [LaMa](https://github.com/saic-mdal/lama)
+  1. [LDM](https://github.com/CompVis/latent-diffusion)
 - [x] High resolution support
 - [x] Multi-stroke support. Press and hold the `cmd/ctrl` key to enable multi-stroke mode.
 - [x] Zoom & Pan
@@ -12,8 +12,29 @@ https://user-images.githubusercontent.com/3998421/153323093-b664bb68-2928-480b-b

 ## Quick Start

-- Install requirements: `pip3 install -r requirements.txt`
-- Start server: `python3 main.py --device=cuda --port=8080`
+Install requirements: `pip3 install -r requirements.txt`
+
+### Start server with LaMa model
+
+```bash
+python3 main.py --device=cuda --port=8080 --model=lama
+```
+
+### Start server with LDM model
+
+```bash
+python3 main.py --device=cuda --port=8080 --model=ldm --ldm-steps=50
+```
+
+`--ldm-steps`: the larger the value, the better the result, but the longer inpainting takes.
+
+The diffusion model is **much** slower than GANs (a 1080x720 image takes about 8s on a 3090), but it can produce better results than LaMa.
+
+Blogs about diffusion models:
+
+- https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
+- https://yang-song.github.io/blog/2021/score/
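Once a server is up, the web UI at http://localhost:8080 is the intended client, but the `/inpaint` endpoint can also be called directly. A minimal client sketch: the `image` form-field name appears in `main.py` below, while the `mask` field name is an assumption, since that part of the handler is elided in this diff.

```python
# Minimal sketch of a direct call to the Flask /inpaint endpoint.
import requests

with open("photo.jpg", "rb") as img, open("mask.png", "rb") as mask:
    resp = requests.post(
        "http://localhost:8080/inpaint",
        files={"image": img, "mask": mask},  # "mask" field name assumed
    )
resp.raise_for_status()
with open("result.png", "wb") as out:
    out.write(resp.content)  # the endpoint streams back the inpainted image bytes
```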
 ## Development

@@ -21,8 +42,8 @@ Only needed if you plan to modify the frontend and recompile yourself.

 ### Frontend

-Frontend code are modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures),
-You can experience their great online services [here](https://cleanup.pictures/).
+The frontend code is modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures); you can experience their
+great online service [here](https://cleanup.pictures/).

 - Install dependencies: `cd lama_cleaner/app/ && yarn`
 - Start development server: `yarn dev`

@@ -30,8 +51,8 @@ You can experience their great online services [here](https://cleanup.pictures/)

 ## Docker

-Run within a Docker container. Set the `CACHE_DIR` to models location path.
-Optionally add a `-d` option to the `docker run` command below to run as a daemon.
+Run within a Docker container. Set `CACHE_DIR` to the model storage path. Optionally add a `-d` option to
+the `docker run` command below to run as a daemon.

 ### Build Docker image

@@ -54,6 +75,7 @@ docker run --gpus all -p 8080:8080 -e CACHE_DIR=/app/models -v $(pwd)/models:/ap
 Then open [http://localhost:8080](http://localhost:8080)

 ## Like My Work?

 <a href="https://www.buymeacoffee.com/Sanster">
   <img height="50em" src="https://cdn.buymeacoffee.com/buttons/v2/default-blue.png" alt="Sanster" />
 </a>
lama_cleaner/helper.py

@@ -7,13 +7,8 @@ import numpy as np
 import torch
 from torch.hub import download_url_to_file, get_dir

-LAMA_MODEL_URL = os.environ.get(
-    "LAMA_MODEL_URL",
-    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-)
-

-def download_model(url=LAMA_MODEL_URL):
+def download_model(url):
     parts = urlparse(url)
     hub_dir = get_dir()
     model_dir = os.path.join(hub_dir, "checkpoints")
lama_cleaner/lama/__init__.py (new file, 56 lines)

@@ -0,0 +1,56 @@
import os
import time

import cv2
import torch
import numpy as np

from lama_cleaner.helper import pad_img_to_modulo, download_model

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)


class LaMa:
    def __init__(self, device):
        self.device = device

        if os.environ.get("LAMA_MODEL"):
            model_path = os.environ.get("LAMA_MODEL")
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
        else:
            model_path = download_model(LAMA_MODEL_URL)

        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        image: [C, H, W] RGB
        mask: [1, H, W]
        return: BGR IMAGE
        """
        device = self.device
        origin_height, origin_width = image.shape[1:]
        image = pad_img_to_modulo(image, mod=8)
        mask = pad_img_to_modulo(mask, mod=8)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(device)

        start = time.time()
        inpainted_image = self.model(image, mask)

        print(f"process time: {(time.time() - start) * 1000}ms")
        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = cur_res[0:origin_height, 0:origin_width, :]
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
        return cur_res
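A minimal usage sketch for the class above, relying only on the docstring contract; float inputs in [0, 1] are an assumption based on the `* 255` scaling at the end, and the random arrays are placeholders for a real image/mask pair:

```python
# Usage sketch for LaMa, assuming [0, 1] float32 inputs (placeholders below).
import numpy as np
from lama_cleaner.lama import LaMa

model = LaMa(device="cpu")  # downloads the TorchScript weights on first run
image = np.random.rand(3, 512, 512).astype("float32")          # [C, H, W] RGB
mask = (np.random.rand(1, 512, 512) > 0.95).astype("float32")  # [1, H, W]
result = model(image, mask)  # uint8 [H, W, 3] BGR, cropped to the input size
```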
lama_cleaner/ldm/__init__.py (new file, 415 lines)

@@ -0,0 +1,415 @@
import os

import numpy as np
import torch

torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
import cv2
from lama_cleaner.helper import pad_img_to_modulo, download_model
from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
    timestep_embedding

LDM_ENCODE_MODEL_URL = os.environ.get(
    "LDM_ENCODE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
)

LDM_DECODE_MODEL_URL = os.environ.get(
    "LDM_DECODE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
)

LDM_DIFFUSION_MODEL_URL = os.environ.get(
    "LDM_DIFFUSION_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
)


class DDPM(nn.Module):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 device,
                 timesteps=1000,
                 beta_schedule="linear",
                 linear_start=0.0015,
                 linear_end=0.0205,
                 cosine_s=0.008,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 parameterization="eps",  # all assuming fixed variance schedules
                 use_positional_encodings=False):
        super().__init__()
        self.device = device
        self.parameterization = parameterization
        self.use_positional_encodings = use_positional_encodings

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        betas = make_beta_schedule(self.device, beta_schedule, timesteps, linear_start=linear_start,
                                   linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()
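For reference, the `posterior_variance` buffer registered above is the standard DDPM posterior variance, with `v_posterior` (v, 0 by default) interpolating toward beta_t:

```latex
% posterior variance of q(x_{t-1} | x_t, x_0), as computed above
\tilde\beta_t = (1 - v)\,\beta_t\,\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} + v\,\beta_t
```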
class LatentDiffusion(DDPM):
    def __init__(self,
                 diffusion_model,
                 device,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = 1
        self.scale_by_std = scale_by_std
        super().__init__(device, *args, **kwargs)
        self.diffusion_model = diffusion_model
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.num_downs = 2
        self.scale_factor = scale_factor

    def make_cond_schedule(self, ):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids

    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def apply_model(self, x_noisy, t, cond):
        # x_recon = self.model(x_noisy, t, cond['c_concat'][0])  # cond['c_concat'][0].shape 1,4,128,128
        t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
        x_recon = self.diffusion_model(x_noisy, t_emb, cond)
        return x_recon


class DDIMSampler(object):
    def __init__(self, model, schedule="linear"):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  # array([1])
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self, steps, conditioning, batch_size, shape):
        self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        # samples: 1,3,128,128
        return self.ddim_sampling(conditioning,
                                  size,
                                  quantize_denoised=False,
                                  ddim_use_original_steps=False,
                                  noise_dropout=0,
                                  temperature=1.,
                                  )

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      ddim_use_original_steps=False,
                      quantize_denoised=False,
                      temperature=1., noise_dropout=0.):
        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)  # used
        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps  # used

        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout)
            img, _ = outs

        return img

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0.):
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:  # not used: sample() always passes quantize_denoised=False
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:  # not used: sample() always passes noise_dropout=0
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
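`p_sample_ddim` above implements the DDIM update from https://arxiv.org/abs/2010.02502; since `sample()` passes `ddim_eta=0`, every sigma_t is zero and the chain is deterministic:

```latex
% DDIM step implemented by p_sample_ddim (pred_x0, dir_xt, noise in the code)
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, c, t)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0
  + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^{2}}\,\epsilon_\theta(x_t, c, t)
  + \sigma_t\,\varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I)
```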
def load_jit_model(url, device):
    model_path = download_model(url)
    model = torch.jit.load(model_path).to(device)
    model.eval()
    return model


class LDM:
    def __init__(self, device, steps=50):
        self.device = device

        self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
        self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
        self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)

        model = LatentDiffusion(self.diffusion_model, device)
        self.sampler = DDIMSampler(model)
        self.steps = steps

    def _norm(self, tensor):
        return tensor * 2.0 - 1.0

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        image: [C, H, W] RGB
        mask: [1, H, W]
        return: BGR IMAGE
        """
        # image [1,3,512,512] float32
        # mask: [1,1,512,512] float32
        # masked_image: [1,3,512,512] float32
        origin_height, origin_width = image.shape[1:]
        image = pad_img_to_modulo(image, mod=32)
        mask = pad_img_to_modulo(mask, mod=32)
        padded_height, padded_width = image.shape[1:]
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # crop 512 x 512
        if padded_width <= 512 or padded_height <= 512:
            np_img = self._forward(image, mask, self.device)
        else:
            print("Try to zoom in")
            # zoom in
            # x,y,w,h
            # box = self.box_from_bitmap(mask)
            box = self.find_main_content(mask)
            if box is None:
                print("No bbox found")
                np_img = self._forward(image, mask, self.device)
            else:
                print(f"box: {box}")
                box_x, box_y, box_w, box_h = box
                cx = box_x + box_w // 2
                cy = box_y + box_h // 2

                w = max(512, box_w)
                h = max(512, box_h)

                left = max(cx - w // 2, 0)
                top = max(cy - h // 2, 0)
                right = min(cx + w // 2, origin_width)
                bottom = min(cy + h // 2, origin_height)

                x = left
                y = top
                w = right - left
                h = bottom - top

                crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
                crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]

                print(f"Apply zoom in size width x height: {crop_img.shape}")

                crop_img_height, crop_img_width = crop_img.shape[1:]

                crop_img = pad_img_to_modulo(crop_img, mod=32)
                crop_mask = pad_img_to_modulo(crop_mask, mod=32)
                # RGB
                np_img = self._forward(crop_img, crop_mask, self.device)

                image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
                image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
                np_img = image
                # BGR to RGB
                # np_img = image[:, :, ::-1]

        np_img = np_img[0:origin_height, 0:origin_width, :]
        np_img = np_img[:, :, ::-1]

        return np_img

    def _forward(self, image, mask, device):
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
        masked_image = (1 - mask) * image

        image = self._norm(image)
        mask = self._norm(mask)
        masked_image = self._norm(masked_image)

        c = self.cond_stage_model_encode(masked_image)

        cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:])  # 1,1,128,128
        c = torch.cat((c, cc), dim=1)  # 1,4,128,128

        shape = (c.shape[1] - 1,) + c.shape[2:]
        samples_ddim = self.sampler.sample(steps=self.steps,
                                           conditioning=c,
                                           batch_size=c.shape[0],
                                           shape=shape)
        x_samples_ddim = self.cond_stage_model_decode(samples_ddim)  # samples_ddim: 1, 3, 128, 128 float32

        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
        mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
        predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        inpainted = (1 - mask) * image + mask * predicted_image
        inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
        np_img = inpainted.astype(np.uint8)
        return np_img
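`_forward` does all inpainting in latent space: the masked image is encoded, the mask is resized to latent resolution and concatenated as a fourth conditioning channel (hence `c.shape[1] - 1` when recovering the sample shape), DDIM denoises a random latent under that condition, and the decoded sample is composited back so only masked pixels change:

```latex
% final pixel-space composite at the end of _forward
x_{\text{out}} = (1 - m) \odot x + m \odot \hat{x}
```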
    def find_main_content(self, bitmap: np.ndarray):
        th2 = bitmap[0].astype(np.uint8)
        row_sum = th2.sum(1)
        col_sum = th2.sum(0)
        xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
        xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
        ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
        ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])

        left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
        return left, top, right - left, bottom - top

    def box_from_bitmap(self, bitmap):
        """
        bitmap: single map with shape (NUM_CLASSES, H, W),
        whose values are binarized as {0, 1}
        """
        contours, _ = cv2.findContours(
            (bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
        )

        contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
        num_contours = len(contours)
        print(f"contours size: {num_contours}")
        if num_contours != 1:
            return None

        # x,y,w,h
        return cv2.boundingRect(contours[0])
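A small worked example of the `find_main_content` geometry above: it pads the tight bounding box of the mask by 20px per side and clamps to the image. Since the method never reads `self`, it can be exercised without instantiating `LDM`:

```python
# Worked example: a single 30x20-pixel blob in a 256x256 mask.
import numpy as np
from lama_cleaner.ldm import LDM

mask = np.zeros((1, 256, 256), dtype=np.float32)
mask[0, 100:120, 60:90] = 1  # rows 100..119, cols 60..89
# tight bbox x=60..89, y=100..119; padded by 20 -> x=40..109, y=80..139
print(LDM.find_main_content(None, mask))  # (40, 80, 69, 59) as (x, y, w, h)
```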
lama_cleaner/ldm/utils.py (new file, 86 lines)

@@ -0,0 +1,86 @@
import math

import torch
import numpy as np


def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device)
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2).to(device)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
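With the values `DDPM.__init__` passes in above (`linear_start=0.0015`, `linear_end=0.0205`), the "linear" branch is linear in sqrt(beta), so the endpoints are preserved exactly:

```python
# The "linear" branch with the defaults used by DDPM in this commit.
import torch

n, start, end = 1000, 0.0015, 0.0205
betas = torch.linspace(start ** 0.5, end ** 0.5, n, dtype=torch.float64) ** 2
print(float(betas[0]), float(betas[-1]))  # ~0.0015 ~0.0205 (up to float rounding)
```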
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
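The `sigmas` line above implements the sigma_t schedule from the DDIM paper; eta = 0 (the value `DDIMSampler.sample` uses) makes every sigma_t zero:

```latex
% sigma schedule computed above
\sigma_t(\eta) = \eta\,\sqrt{\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}}\,\sqrt{1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}}
```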
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
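For the configuration used in this commit (1000 DDPM steps, `--ldm-steps=50`, 'uniform' discretization), the selected timesteps work out to every 20th step, shifted by one:

```python
# Mirrors the 'uniform' branch of make_ddim_timesteps with this commit's defaults.
import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 50
c = num_ddpm_timesteps // num_ddim_timesteps                     # 20
steps_out = np.asarray(list(range(0, num_ddpm_timesteps, c))) + 1
print(steps_out[:4], steps_out[-1], len(steps_out))  # [ 1 21 41 61] 981 50
```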
def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=device)

    args = timesteps[:, None].float() * freqs[None]

    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
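Concretely, with f_i = 10000^(-2i/d), `timestep_embedding` returns the cosine half followed by the sine half (not interleaved):

```latex
% layout of the embedding computed above, d = dim
\mathrm{emb}(t) = \big[\cos(t f_0), \dots, \cos(t f_{d/2-1}),\ \sin(t f_0), \dots, \sin(t f_{d/2-1})\big],
\qquad f_i = 10000^{-2i/d}
```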
main.py (63 lines changed)

@@ -4,14 +4,14 @@ import argparse
 import io
 import multiprocessing
 import os
 import time
 from distutils.util import strtobool
 from typing import Union

 import cv2
 import numpy as np
 import torch

+from lama_cleaner.lama import LaMa
+from lama_cleaner.ldm import LDM

 try:
     torch._C._jit_override_can_fuse_on_cpu(False)
     torch._C._jit_override_can_fuse_on_gpu(False)
@@ -24,13 +24,10 @@ from flask import Flask, request, send_file
 from flask_cors import CORS

 from lama_cleaner.helper import (
-    download_model,
     load_img,
     norm_img,
     numpy_to_bytes,
-    pad_img_to_modulo,
-    resize_max_size,
-)
+    resize_max_size, )

 NUM_THREADS = str(multiprocessing.cpu_count())

@@ -55,6 +52,7 @@ device = None
 @app.route("/inpaint", methods=["POST"])
 def process():
     input = request.files
+    # RGB
     image = load_img(input["image"].read())
     original_shape = image.shape
     interpolation = cv2.INTER_CUBIC

@@ -74,14 +72,7 @@ def process():
     mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
     mask = norm_img(mask)

-    res_np_img = run(image, mask)
-
-    # resize to original size
-    # res_np_img = cv2.resize(
-    #     res_np_img,
-    #     dsize=(original_shape[1], original_shape[0]),
-    #     interpolation=interpolation,
-    # )
+    res_np_img = model(image, mask)

     return send_file(
         io.BytesIO(numpy_to_bytes(res_np_img)),

@@ -96,35 +87,12 @@ def index():
     return send_file(os.path.join(BUILD_DIR, "index.html"))


-def run(image, mask):
-    """
-    image: [C, H, W]
-    mask: [1, H, W]
-    return: BGR IMAGE
-    """
-    origin_height, origin_width = image.shape[1:]
-    image = pad_img_to_modulo(image, mod=8)
-    mask = pad_img_to_modulo(mask, mod=8)
-
-    mask = (mask > 0) * 1
-    image = torch.from_numpy(image).unsqueeze(0).to(device)
-    mask = torch.from_numpy(mask).unsqueeze(0).to(device)
-
-    start = time.time()
-    with torch.no_grad():
-        inpainted_image = model(image, mask)
-
-    print(f"process time: {(time.time() - start) * 1000}ms")
-    cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
-    cur_res = cur_res[0:origin_height, 0:origin_width, :]
-    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
-    cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
-    return cur_res
-
-
 def get_args_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", default=8080, type=int)
+    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
+    parser.add_argument("--ldm-steps", default=50, type=int,
+                        help="Steps for DDIM sampling process. "
+                             "The larger the value, the better the result, but it will be more time-consuming")
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--debug", action="store_true")
     return parser.parse_args()

@@ -136,16 +104,13 @@ def main():
     args = get_args_parser()
     device = torch.device(args.device)

-    if os.environ.get("LAMA_MODEL"):
-        model_path = os.environ.get("LAMA_MODEL")
-        if not os.path.exists(model_path):
-            raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
+    if args.model == "lama":
+        model = LaMa(device)
+    elif args.model == "ldm":
+        model = LDM(device, steps=args.ldm_steps)
     else:
-        model_path = download_model()
+        raise NotImplementedError(f"Not supported model: {args.model}")

-    model = torch.jit.load(model_path, map_location="cpu")
-    model = model.to(device)
-    model.eval()
     app.run(host="0.0.0.0", port=args.port, debug=args.debug)
requirements.txt

@@ -2,3 +2,4 @@ torch>=1.8.2
 opencv-python
 flask_cors
 flask
+tqdm