From f9b96cf218a849c8b8991cad6e6f9a03d5fa3472 Mon Sep 17 00:00:00 2001
From: Sanster
Date: Fri, 4 Mar 2022 13:44:53 +0800
Subject: [PATCH] add ldm model

---
 README.md                     |  42 +++-
 lama_cleaner/helper.py        |   7 +-
 lama_cleaner/lama/__init__.py |  56 +++++
 lama_cleaner/ldm/__init__.py  | 415 ++++++++++++++++++++++++++++++++++
 lama_cleaner/ldm/utils.py     |  86 +++++++
 main.py                       |  63 ++----
 requirements.txt              |   1 +
 7 files changed, 605 insertions(+), 65 deletions(-)
 create mode 100644 lama_cleaner/lama/__init__.py
 create mode 100644 lama_cleaner/ldm/__init__.py
 create mode 100644 lama_cleaner/ldm/utils.py

diff --git a/README.md b/README.md
index 097d77a..9d36c8d 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
-# Lama-cleaner: Image inpainting tool powered by [LaMa](https://github.com/saic-mdal/lama)
-
-This project is mainly used for selfhosting LaMa model, some interaction improvements may be added later.
+# Lama-cleaner: Image inpainting tool powered by SOTA AI models
 
 https://user-images.githubusercontent.com/3998421/153323093-b664bb68-2928-480b-b59b-7c1ee24a4507.mp4
 
-
+- [x] Support multiple model architectures
+  1. [LaMa](https://github.com/saic-mdal/lama)
+  1. [LDM](https://github.com/CompVis/latent-diffusion)
 - [x] High resolution support
 - [x] Multi stroke support. Press and hold the `cmd/ctrl` key to enable multi stroke mode.
 - [x] Zoom & Pan
@@ -12,8 +12,29 @@
 
 ## Quick Start
 
-- Install requirements: `pip3 install -r requirements.txt`
-- Start server: `python3 main.py --device=cuda --port=8080`
+Install requirements: `pip3 install -r requirements.txt`
+
+### Start server with LaMa model
+
+```bash
+python3 main.py --device=cuda --port=8080 --model=lama
+```
+
+### Start server with LDM model
+
+```bash
+python3 main.py --device=cuda --port=8080 --model=ldm --ldm-steps=50
+```
+
+`--ldm-steps`: the larger the value, the better the result, but the slower the sampling.
+
+The diffusion model is **much** slower than GANs (a 1080x720 image takes ~8s on a 3090), but it can produce
+better results than LaMa.
+
+Blog posts about diffusion models:
+
+- https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
+- https://yang-song.github.io/blog/2021/score/
 
 ## Development
 
@@ -21,8 +42,8 @@ Only needed if you plan to modify the frontend and recompile yourself.
 
 ### Frontend
 
-Frontend code are modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures),
-You can experience their great online services [here](https://cleanup.pictures/).
+Frontend code is modified from [cleanup.pictures](https://github.com/initml/cleanup.pictures); you can try their
+great online service [here](https://cleanup.pictures/).
 
 - Install dependencies:`cd lama_cleaner/app/ && yarn`
 - Start development server: `yarn dev`
 
@@ -30,8 +51,8 @@ ## Docker
 
-Run within a Docker container. Set the `CACHE_DIR` to models location path.
-Optionally add a `-d` option to the `docker run` command below to run as a daemon.
+Run within a Docker container. Set `CACHE_DIR` to the models' location path. Optionally add a `-d` option to
+the `docker run` command below to run as a daemon.
 
 ### Build Docker image
 
@@ -54,6 +75,7 @@
 docker run --gpus all -p 8080:8080 -e CACHE_DIR=/app/models -v $(pwd)/models:/app/models
 
 Then open [http://localhost:8080](http://localhost:8080)
 
 ## Like My Work?
+ Sanster
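Once a server is running, the `/inpaint` route in `main.py` accepts a multipart POST. A hypothetical client call might look like the sketch below; the `image` field name matches the `request.files` usage in `main.py`, while the `mask` field name and file names are assumptions:

```python
import requests

# Assumed request shape: the server reads the uploaded image and mask
# from request.files and returns the inpainted result as image bytes.
with open("photo.png", "rb") as img, open("mask.png", "rb") as mask:
    resp = requests.post(
        "http://localhost:8080/inpaint",
        files={"image": img, "mask": mask},
    )

with open("result.png", "wb") as out:
    out.write(resp.content)
```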
diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py
index cbf8b92..943d1be 100644
--- a/lama_cleaner/helper.py
+++ b/lama_cleaner/helper.py
@@ -7,13 +7,8 @@
 import numpy as np
 import torch
 from torch.hub import download_url_to_file, get_dir
 
-LAMA_MODEL_URL = os.environ.get(
-    "LAMA_MODEL_URL",
-    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-)
-
 
-def download_model(url=LAMA_MODEL_URL):
+def download_model(url):
     parts = urlparse(url)
     hub_dir = get_dir()
     model_dir = os.path.join(hub_dir, "checkpoints")
diff --git a/lama_cleaner/lama/__init__.py b/lama_cleaner/lama/__init__.py
new file mode 100644
index 0000000..2fc0a60
--- /dev/null
+++ b/lama_cleaner/lama/__init__.py
@@ -0,0 +1,56 @@
+import os
+import time
+
+import cv2
+import torch
+import numpy as np
+
+from lama_cleaner.helper import pad_img_to_modulo, download_model
+
+LAMA_MODEL_URL = os.environ.get(
+    "LAMA_MODEL_URL",
+    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+)
+
+
+class LaMa:
+    def __init__(self, device):
+        self.device = device
+
+        if os.environ.get("LAMA_MODEL"):
+            model_path = os.environ.get("LAMA_MODEL")
+            if not os.path.exists(model_path):
+                raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
+        else:
+            model_path = download_model(LAMA_MODEL_URL)
+
+        model = torch.jit.load(model_path, map_location="cpu")
+        model = model.to(device)
+        model.eval()
+        self.model = model
+
+    @torch.no_grad()
+    def __call__(self, image, mask):
+        """
+        image: [C, H, W] RGB
+        mask: [1, H, W]
+        return: BGR IMAGE
+        """
+        device = self.device
+        origin_height, origin_width = image.shape[1:]
+        image = pad_img_to_modulo(image, mod=8)
+        mask = pad_img_to_modulo(mask, mod=8)
+
+        mask = (mask > 0) * 1
+        image = torch.from_numpy(image).unsqueeze(0).to(device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+
+        start = time.time()
+        inpainted_image = self.model(image, mask)
+
+        print(f"process time: {(time.time() - start) * 1000}ms")
+        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
+        cur_res = cur_res[0:origin_height, 0:origin_width, :]
+        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
+        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
+        return cur_res
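Both model wrappers rely on `pad_img_to_modulo` from `lama_cleaner/helper.py`, which this patch does not show. A minimal sketch of the expected behavior, padding a `[C, H, W]` array so H and W become multiples of `mod` (the padding mode is an assumption):

```python
import numpy as np

def pad_img_to_modulo(img: np.ndarray, mod: int) -> np.ndarray:
    """Sketch of the helper this patch relies on (the real one lives in
    lama_cleaner/helper.py): pad [C, H, W] so H and W are multiples of mod."""
    channels, height, width = img.shape
    out_height = (height + mod - 1) // mod * mod
    out_width = (width + mod - 1) // mod * mod
    return np.pad(
        img,
        ((0, 0), (0, out_height - height), (0, out_width - width)),
        mode="symmetric",  # assumed; any edge-replicating mode works here
    )
```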
diff --git a/lama_cleaner/ldm/__init__.py b/lama_cleaner/ldm/__init__.py
new file mode 100644
index 0000000..092cab9
--- /dev/null
+++ b/lama_cleaner/ldm/__init__.py
@@ -0,0 +1,415 @@
+import os
+
+import numpy as np
+import torch
+
+torch.manual_seed(42)
+import torch.nn as nn
+from tqdm import tqdm
+import cv2
+from lama_cleaner.helper import pad_img_to_modulo, download_model
+from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
+    timestep_embedding
+
+LDM_ENCODE_MODEL_URL = os.environ.get(
+    "LDM_ENCODE_MODEL_URL",
+    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
+)
+
+LDM_DECODE_MODEL_URL = os.environ.get(
+    "LDM_DECODE_MODEL_URL",
+    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
+)
+
+LDM_DIFFUSION_MODEL_URL = os.environ.get(
+    "LDM_DIFFUSION_MODEL_URL",
+    "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
+)
+
+
+class DDPM(nn.Module):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(self,
+                 device,
+                 timesteps=1000,
+                 beta_schedule="linear",
+                 linear_start=0.0015,
+                 linear_end=0.0205,
+                 cosine_s=0.008,
+                 original_elbo_weight=0.,
+                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+                 l_simple_weight=1.,
+                 parameterization="eps",  # all assuming fixed variance schedules
+                 use_positional_encodings=False):
+        super().__init__()
+        self.device = device
+        self.parameterization = parameterization
+        self.use_positional_encodings = use_positional_encodings
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(self.device, beta_schedule, timesteps, linear_start=linear_start,
+                                   linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+                1. - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+        elif self.parameterization == "x0":
+            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+        else:
+            raise NotImplementedError("mu not supported")
+        # TODO how to choose this term
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
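`register_schedule` above precomputes all the schedule constants as buffers. A standalone numpy check of the core quantities, using the DDPM defaults from this patch (`linear_start=0.0015`, `linear_end=0.0205`, 1000 timesteps, `v_posterior=0`), might look like:

```python
import numpy as np

# Mirror make_beta_schedule("linear") and register_schedule() above.
betas = np.linspace(0.0015 ** 0.5, 0.0205 ** 0.5, 1000) ** 2
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

# With v_posterior = 0 this reduces to beta_t * (1 - abar_{t-1}) / (1 - abar_t).
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
print(betas[:3], posterior_variance[:3])
```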
+
+
+class LatentDiffusion(DDPM):
+    def __init__(self,
+                 diffusion_model,
+                 device,
+                 cond_stage_key="image",
+                 cond_stage_trainable=False,
+                 concat_mode=True,
+                 scale_factor=1.0,
+                 scale_by_std=False,
+                 *args, **kwargs):
+        self.num_timesteps_cond = 1
+        self.scale_by_std = scale_by_std
+        super().__init__(device, *args, **kwargs)
+        self.diffusion_model = diffusion_model
+        self.concat_mode = concat_mode
+        self.cond_stage_trainable = cond_stage_trainable
+        self.cond_stage_key = cond_stage_key
+        self.num_downs = 2
+        self.scale_factor = scale_factor
+
+    def make_cond_schedule(self, ):
+        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+        self.cond_ids[:self.num_timesteps_cond] = ids
+
+    def register_schedule(self,
+                          given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
+    def apply_model(self, x_noisy, t, cond):
+        # x_recon = self.model(x_noisy, t, cond['c_concat'][0])  # cond['c_concat'][0].shape 1,4,128,128
+        t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
+        x_recon = self.diffusion_model(x_noisy, t_emb, cond)
+        return x_recon
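`apply_model` above feeds the scripted UNet a 256-dimensional sinusoidal embedding of the timestep. A quick shape check using `timestep_embedding` from `lama_cleaner/ldm/utils.py` (added later in this patch):

```python
import torch
from lama_cleaner.ldm.utils import timestep_embedding

t = torch.tensor([250, 500])          # one timestep index per batch element
emb = timestep_embedding("cpu", t, 256)
print(emb.shape)                      # torch.Size([2, 256])
```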
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear"):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
+                                                  num_ddim_timesteps=ddim_num_steps,  # array([1])
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta, verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self, steps, conditioning, batch_size, shape):
+        self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        # samples: 1,3,128,128
+        return self.ddim_sampling(conditioning,
+                                  size,
+                                  quantize_denoised=False,
+                                  ddim_use_original_steps=False,
+                                  noise_dropout=0,
+                                  temperature=1.,
+                                  )
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      ddim_use_original_steps=False,
+                      quantize_denoised=False,
+                      temperature=1., noise_dropout=0.):
+        device = self.model.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)  # used: sampling starts from pure noise
+        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps  # used: DDIM subsequence
+
+        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout)
+            img, _ = outs
+
+        return img
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0.):
+        b, *_, device = *x.shape, x.device
+        e_t = self.model.apply_model(x, t, c)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+        # current prediction for x_0
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        if quantize_denoised:  # not used in this pipeline
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:  # not used in this pipeline
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
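`p_sample_ddim` above performs one reverse step. Since `sample()` passes `ddim_eta=0`, `sigma_t` is zero and the update is deterministic; a scalar restatement of the same arithmetic:

```python
import numpy as np

# One DDIM update with eta = 0 (sigma_t = 0, no injected noise):
#   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
#   x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * e_t
a_t, a_prev = 0.9, 0.95        # toy alpha_cumprod values for adjacent steps
x_t, e_t = 1.0, 0.1            # toy sample and predicted noise
pred_x0 = (x_t - np.sqrt(1 - a_t) * e_t) / np.sqrt(a_t)
x_prev = np.sqrt(a_prev) * pred_x0 + np.sqrt(1 - a_prev) * e_t
print(pred_x0, x_prev)
```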
+
+
+def load_jit_model(url, device):
+    model_path = download_model(url)
+    model = torch.jit.load(model_path).to(device)
+    model.eval()
+    return model
+
+
+class LDM:
+    def __init__(self, device, steps=50):
+        self.device = device
+
+        self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
+        self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
+        self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
+
+        model = LatentDiffusion(self.diffusion_model, device)
+        self.sampler = DDIMSampler(model)
+        self.steps = steps
+
+    def _norm(self, tensor):
+        return tensor * 2.0 - 1.0
+
+    @torch.no_grad()
+    def __call__(self, image, mask):
+        """
+        image: [C, H, W] RGB
+        mask: [1, H, W]
+        return: BGR IMAGE
+        """
+        # image: [1,3,512,512] float32
+        # mask: [1,1,512,512] float32
+        # masked_image: [1,3,512,512] float32
+        origin_height, origin_width = image.shape[1:]
+        image = pad_img_to_modulo(image, mod=32)
+        mask = pad_img_to_modulo(mask, mod=32)
+        padded_height, padded_width = image.shape[1:]
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+        # crop 512 x 512
+        if padded_width <= 512 or padded_height <= 512:
+            np_img = self._forward(image, mask, self.device)
+        else:
+            print("Try to zoom in")
+            # zoom in
+            # x,y,w,h
+            # box = self.box_from_bitmap(mask)
+            box = self.find_main_content(mask)
+            if box is None:
+                print("No bbox found")
+                np_img = self._forward(image, mask, self.device)
+            else:
+                print(f"box: {box}")
+                box_x, box_y, box_w, box_h = box
+                cx = box_x + box_w // 2
+                cy = box_y + box_h // 2
+
+                w = max(512, box_w)
+                h = max(512, box_h)
+
+                left = max(cx - w // 2, 0)
+                top = max(cy - h // 2, 0)
+                right = min(cx + w // 2, origin_width)
+                bottom = min(cy + h // 2, origin_height)
+
+                x = left
+                y = top
+                w = right - left
+                h = bottom - top
+
+                crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
+                crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
+
+                print(f"Apply zoom in size width x height: {crop_img.shape}")
+
+                crop_img_height, crop_img_width = crop_img.shape[1:]
+
+                crop_img = pad_img_to_modulo(crop_img, mod=32)
+                crop_mask = pad_img_to_modulo(crop_mask, mod=32)
+                # RGB
+                np_img = self._forward(crop_img, crop_mask, self.device)
+
+                image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
+                image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
+                np_img = image
+                # BGR to RGB
+                # np_img = image[:, :, ::-1]
+
+        np_img = np_img[0:origin_height, 0:origin_width, :]
+        np_img = np_img[:, :, ::-1]
+
+        return np_img
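The zoom-in branch above crops a region of at least 512x512 centered on the mask's bounding box, clamped to the image bounds. The same arithmetic, restated as a standalone helper (hypothetical name, not part of the patch):

```python
def centered_crop_box(box, img_w, img_h, min_size=512):
    """Restates the zoom-in arithmetic from LDM.__call__ (hypothetical helper):
    a crop of at least min_size x min_size centered on the mask bbox,
    clamped to the image bounds."""
    box_x, box_y, box_w, box_h = box
    cx, cy = box_x + box_w // 2, box_y + box_h // 2
    w, h = max(min_size, box_w), max(min_size, box_h)
    left, top = max(cx - w // 2, 0), max(cy - h // 2, 0)
    right, bottom = min(cx + w // 2, img_w), min(cy + h // 2, img_h)
    return left, top, right - left, bottom - top
```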
+
+    def _forward(self, image, mask, device):
+        image = torch.from_numpy(image).unsqueeze(0).to(device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+        masked_image = (1 - mask) * image
+
+        image = self._norm(image)
+        mask = self._norm(mask)
+        masked_image = self._norm(masked_image)
+
+        c = self.cond_stage_model_encode(masked_image)
+
+        cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:])  # 1,1,128,128
+        c = torch.cat((c, cc), dim=1)  # 1,4,128,128
+
+        shape = (c.shape[1] - 1,) + c.shape[2:]
+        samples_ddim = self.sampler.sample(steps=self.steps,
+                                           conditioning=c,
+                                           batch_size=c.shape[0],
+                                           shape=shape)
+        x_samples_ddim = self.cond_stage_model_decode(samples_ddim)  # samples_ddim: 1, 3, 128, 128 float32
+
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
+        predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+        inpainted = (1 - mask) * image + mask * predicted_image
+        inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
+        np_img = inpainted.astype(np.uint8)
+        return np_img
+
+    def find_main_content(self, bitmap: np.ndarray):
+        th2 = bitmap[0].astype(np.uint8)
+        row_sum = th2.sum(1)
+        col_sum = th2.sum(0)
+        xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
+        xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
+        ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
+        ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
+
+        left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
+        return left, top, right - left, bottom - top
+
+    def box_from_bitmap(self, bitmap):
+        """
+        bitmap: single map with shape (NUM_CLASSES, H, W),
+            whose values are binarized as {0, 1}
+        """
+        contours, _ = cv2.findContours(
+            (bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
+        )
+
+        contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
+        num_contours = len(contours)
+        print(f"contours size: {num_contours}")
+        if num_contours != 1:
+            return None
+
+        # x,y,w,h
+        return cv2.boundingRect(contours[0])
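`find_main_content` above projects the mask onto its rows and columns and returns the non-zero region's bounding box, expanded by a 20-pixel margin. A toy example of the expected result:

```python
import numpy as np

# A 1 x 100 x 100 mask with a filled rectangle: rows 40-59, cols 30-69.
mask = np.zeros((1, 100, 100), dtype=np.float32)
mask[0, 40:60, 30:70] = 1

# find_main_content(mask) -> (10, 20, 79, 59):
#   x = max(0, 30 - 20) = 10, y = max(0, 40 - 20) = 20,
#   w = min(69 + 20, 100) - 10 = 79, h = min(59 + 20, 100) - 20 = 59.
```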
diff --git a/lama_cleaner/ldm/utils.py b/lama_cleaner/ldm/utils.py
new file mode 100644
index 0000000..1158a44
--- /dev/null
+++ b/lama_cleaner/ldm/utils.py
@@ -0,0 +1,86 @@
+import math
+
+import torch
+import numpy as np
+
+
+def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+    if schedule == "linear":
+        betas = (
+                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device)
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2).to(device)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = np.clip(betas, a_min=0, a_max=0.999)
+
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+    elif schedule == "sqrt":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+    # select alphas for computing the variance schedule
+    alphas = alphacums[ddim_timesteps]
+    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+    if verbose:
+        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+        print(f'For the chosen value of eta, which is {eta}, '
+              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+    return sigmas, alphas, alphas_prev
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+    if ddim_discr_method == 'uniform':
+        c = num_ddpm_timesteps // num_ddim_timesteps
+        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+    elif ddim_discr_method == 'quad':
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+    else:
+        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+    # add one to get the final alpha values right (the ones from first scale to data during sampling)
+    steps_out = ddim_timesteps + 1
+    if verbose:
+        print(f'Selected timesteps for ddim sampler: {steps_out}')
+    return steps_out
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
+
+
+def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = torch.exp(
+        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+    ).to(device=device)
+
+    args = timesteps[:, None].float() * freqs[None]
+
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
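With the default `--ldm-steps=50` and 1000 DDPM timesteps, the uniform discretization in `make_ddim_timesteps` visits every 20th timestep, shifted by one. Reproducing that selection directly:

```python
import numpy as np

# Uniform DDIM discretization: 1000 DDPM steps sampled with 50 DDIM steps.
c = 1000 // 50
steps = np.asarray(list(range(0, 1000, c))) + 1   # the +1 matches steps_out above
print(steps[:5], steps[-1])                       # [ 1 21 41 61 81] 981
```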
diff --git a/main.py b/main.py
index 355c673..ceaae63 100644
--- a/main.py
+++ b/main.py
@@ -4,14 +4,14 @@ import argparse
 import io
 import multiprocessing
 import os
-import time
-from distutils.util import strtobool
 from typing import Union
 
 import cv2
-import numpy as np
 import torch
 
+from lama_cleaner.lama import LaMa
+from lama_cleaner.ldm import LDM
+
 try:
     torch._C._jit_override_can_fuse_on_cpu(False)
     torch._C._jit_override_can_fuse_on_gpu(False)
@@ -24,13 +24,10 @@ from flask import Flask, request, send_file
 from flask_cors import CORS
 
 from lama_cleaner.helper import (
-    download_model,
     load_img,
     norm_img,
     numpy_to_bytes,
-    pad_img_to_modulo,
-    resize_max_size,
-)
+    resize_max_size, )
 
 NUM_THREADS = str(multiprocessing.cpu_count())
@@ -55,6 +52,7 @@ device = None
 @app.route("/inpaint", methods=["POST"])
 def process():
     input = request.files
+    # RGB
     image = load_img(input["image"].read())
     original_shape = image.shape
     interpolation = cv2.INTER_CUBIC
@@ -74,14 +72,7 @@
     mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
     mask = norm_img(mask)
 
-    res_np_img = run(image, mask)
-
-    # resize to original size
-    # res_np_img = cv2.resize(
-    #     res_np_img,
-    #     dsize=(original_shape[1], original_shape[0]),
-    #     interpolation=interpolation,
-    # )
+    res_np_img = model(image, mask)
 
     return send_file(
         io.BytesIO(numpy_to_bytes(res_np_img)),
@@ -96,35 +87,12 @@ def index():
     return send_file(os.path.join(BUILD_DIR, "index.html"))
 
 
-def run(image, mask):
-    """
-    image: [C, H, W]
-    mask: [1, H, W]
-    return: BGR IMAGE
-    """
-    origin_height, origin_width = image.shape[1:]
-    image = pad_img_to_modulo(image, mod=8)
-    mask = pad_img_to_modulo(mask, mod=8)
-
-    mask = (mask > 0) * 1
-    image = torch.from_numpy(image).unsqueeze(0).to(device)
-    mask = torch.from_numpy(mask).unsqueeze(0).to(device)
-
-    start = time.time()
-    with torch.no_grad():
-        inpainted_image = model(image, mask)
-
-    print(f"process time: {(time.time() - start)*1000}ms")
-    cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
-    cur_res = cur_res[0:origin_height, 0:origin_width, :]
-    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
-    cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
-    return cur_res
-
-
 def get_args_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", default=8080, type=int)
+    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
+    parser.add_argument("--ldm-steps", default=50, type=int,
+                        help="Steps for DDIM sampling process. "
+                             "The larger the value, the better the result, but it will be more time-consuming")
     parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--debug", action="store_true")
     return parser.parse_args()
@@ -136,16 +104,13 @@
     args = get_args_parser()
     device = torch.device(args.device)
 
-    if os.environ.get("LAMA_MODEL"):
-        model_path = os.environ.get("LAMA_MODEL")
-        if not os.path.exists(model_path):
-            raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
+    if args.model == "lama":
+        model = LaMa(device)
+    elif args.model == "ldm":
+        model = LDM(device, steps=args.ldm_steps)
     else:
-        model_path = download_model()
+        raise NotImplementedError(f"Unsupported model: {args.model}")
 
-    model = torch.jit.load(model_path, map_location="cpu")
-    model = model.to(device)
-    model.eval()
     app.run(host="0.0.0.0", port=args.port, debug=args.debug)
diff --git a/requirements.txt b/requirements.txt
index 67b25fa..2cb9902 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ torch>=1.8.2
 opencv-python
 flask_cors
 flask
+tqdm