diff --git a/iopaint/model/anytext/anytext_model.py b/iopaint/model/anytext/anytext_model.py
new file mode 100644
index 0000000..e69de29
diff --git a/iopaint/model/anytext/anytext_pipeline.py b/iopaint/model/anytext/anytext_pipeline.py
new file mode 100644
index 0000000..9e82fe0
--- /dev/null
+++ b/iopaint/model/anytext/anytext_pipeline.py
@@ -0,0 +1,395 @@
+"""
+AnyText: Multilingual Visual Text Generation And Editing
+Paper: https://arxiv.org/abs/2311.03054
+Code: https://github.com/tyxsspa/AnyText
+Copyright (c) Alibaba, Inc. and its affiliates.
+"""
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import torch
+import random
+import re
+import numpy as np
+import cv2
+import einops
+import time
+from PIL import ImageFont
+from iopaint.model.anytext.cldm.model import create_model, load_state_dict
+from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
+from iopaint.model.anytext.utils import (
+ resize_image,
+ check_channels,
+ draw_glyph,
+ draw_glyph2,
+)
+
+
+BBOX_MAX_NUM = 8
+PLACE_HOLDER = "*"
+max_chars = 20
+
+
+class AnyTextPipeline:
+ def __init__(self, cfg_path, model_dir, font_path, device, use_fp16=True):
+ self.cfg_path = cfg_path
+ self.model_dir = model_dir
+ self.font_path = font_path
+ self.use_fp16 = use_fp16
+ self.device = device
+ self.init_model()
+
+ """
+ return:
+ result: list of images in numpy.ndarray format
+ rst_code: 0: normal -1: error 1:warning
+ rst_info: string of error or warning
+ debug_info: string for debug, only valid if show_debug=True
+ """
+
+ def __call__(self, input_tensor, **forward_params):
+ tic = time.time()
+ str_warning = ""
+ # get inputs
+ seed = input_tensor.get("seed", -1)
+ if seed == -1:
+ seed = random.randint(0, 99999999)
+ # seed_everything(seed)
+ prompt = input_tensor.get("prompt")
+ draw_pos = input_tensor.get("draw_pos")
+ ori_image = input_tensor.get("ori_image")
+
+ mode = forward_params.get("mode")
+ sort_priority = forward_params.get("sort_priority", "↕")
+ show_debug = forward_params.get("show_debug", False)
+ revise_pos = forward_params.get("revise_pos", False)
+ img_count = forward_params.get("image_count", 4)
+ ddim_steps = forward_params.get("ddim_steps", 20)
+ w = forward_params.get("image_width", 512)
+ h = forward_params.get("image_height", 512)
+ strength = forward_params.get("strength", 1.0)
+ cfg_scale = forward_params.get("cfg_scale", 9.0)
+ eta = forward_params.get("eta", 0.0)
+ a_prompt = forward_params.get(
+ "a_prompt",
+ "best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks",
+ )
+ n_prompt = forward_params.get(
+ "n_prompt",
+ "low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
+ )
+
+ prompt, texts = self.modify_prompt(prompt)
+ if prompt is None and texts is None:
+ return (
+ None,
+ -1,
+ "You have input Chinese prompt but the translator is not loaded!",
+ "",
+ )
+ n_lines = len(texts)
+ if mode in ["text-generation", "gen"]:
+ edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
+ elif mode in ["text-editing", "edit"]:
+ if draw_pos is None or ori_image is None:
+ return (
+ None,
+ -1,
+ "Reference image and position image are needed for text editing!",
+ "",
+ )
+ if isinstance(ori_image, str):
+ ori_image = cv2.imread(ori_image)[..., ::-1]
+ assert (
+ ori_image is not None
+ ), f"Can't read ori_image image from{ori_image}!"
+ elif isinstance(ori_image, torch.Tensor):
+ ori_image = ori_image.cpu().numpy()
+ else:
+ assert isinstance(
+ ori_image, np.ndarray
+ ), f"Unknown format of ori_image: {type(ori_image)}"
+ edit_image = ori_image.clip(1, 255) # for mask reason
+ edit_image = check_channels(edit_image)
+ edit_image = resize_image(
+ edit_image, max_length=768
+ ) # make w h multiple of 64, resize if w or h > max_length
+ h, w = edit_image.shape[:2] # change h, w by input ref_img
+ # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
+ if draw_pos is None:
+ pos_imgs = np.zeros((w, h, 1))
+ if isinstance(draw_pos, str):
+ draw_pos = cv2.imread(draw_pos)[..., ::-1]
+ assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!"
+ pos_imgs = 255 - draw_pos
+ elif isinstance(draw_pos, torch.Tensor):
+ pos_imgs = draw_pos.cpu().numpy()
+ else:
+ assert isinstance(
+ draw_pos, np.ndarray
+ ), f"Unknown format of draw_pos: {type(draw_pos)}"
+ pos_imgs = pos_imgs[..., 0:1]
+ pos_imgs = cv2.convertScaleAbs(pos_imgs)
+ _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
+ # seprate pos_imgs
+ pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
+ if len(pos_imgs) == 0:
+ pos_imgs = [np.zeros((h, w, 1))]
+ if len(pos_imgs) < n_lines:
+ if n_lines == 1 and texts[0] == " ":
+ pass # text-to-image without text
+ else:
+ return (
+ None,
+ -1,
+ f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!",
+ "",
+ )
+ elif len(pos_imgs) > n_lines:
+ str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
+ # get pre_pos, poly_list, hint that needed for anytext
+ pre_pos = []
+ poly_list = []
+ for input_pos in pos_imgs:
+ if input_pos.mean() != 0:
+ input_pos = (
+ input_pos[..., np.newaxis]
+ if len(input_pos.shape) == 2
+ else input_pos
+ )
+ poly, pos_img = self.find_polygon(input_pos)
+ pre_pos += [pos_img / 255.0]
+ poly_list += [poly]
+ else:
+ pre_pos += [np.zeros((h, w, 1))]
+ poly_list += [None]
+ np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
+ # prepare info dict
+ info = {}
+ info["glyphs"] = []
+ info["gly_line"] = []
+ info["positions"] = []
+ info["n_lines"] = [len(texts)] * img_count
+ gly_pos_imgs = []
+ for i in range(len(texts)):
+ text = texts[i]
+ if len(text) > max_chars:
+ str_warning = (
+ f'"{text}" length > max_chars: {max_chars}, will be cut off...'
+ )
+ text = text[:max_chars]
+ gly_scale = 2
+ if pre_pos[i].mean() != 0:
+ gly_line = draw_glyph(self.font, text)
+ glyphs = draw_glyph2(
+ self.font,
+ text,
+ poly_list[i],
+ scale=gly_scale,
+ width=w,
+ height=h,
+ add_space=False,
+ )
+ gly_pos_img = cv2.drawContours(
+ glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
+ )
+ if revise_pos:
+ resize_gly = cv2.resize(
+ glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
+ )
+ new_pos = cv2.morphologyEx(
+ (resize_gly * 255).astype(np.uint8),
+ cv2.MORPH_CLOSE,
+ kernel=np.ones(
+ (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
+ dtype=np.uint8,
+ ),
+ iterations=1,
+ )
+ new_pos = (
+ new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
+ )
+ contours, _ = cv2.findContours(
+ new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+ )
+ if len(contours) != 1:
+ str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
+ else:
+ rect = cv2.minAreaRect(contours[0])
+ poly = np.int0(cv2.boxPoints(rect))
+ pre_pos[i] = (
+ cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
+ )
+ gly_pos_img = cv2.drawContours(
+ glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
+ )
+ gly_pos_imgs += [gly_pos_img] # for show
+ else:
+ glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
+ gly_line = np.zeros((80, 512, 1))
+ gly_pos_imgs += [
+ np.zeros((h * gly_scale, w * gly_scale, 1))
+ ] # for show
+ pos = pre_pos[i]
+ info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
+ info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
+ info["positions"] += [self.arr2tensor(pos, img_count)]
+ # get masked_x
+ masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
+ masked_img = np.transpose(masked_img, (2, 0, 1))
+ masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
+ if self.use_fp16:
+ masked_img = masked_img.half()
+ encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
+ masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
+ if self.use_fp16:
+ masked_x = masked_x.half()
+ info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
+
+ hint = self.arr2tensor(np_hint, img_count)
+ cond = self.model.get_learned_conditioning(
+ dict(
+ c_concat=[hint],
+ c_crossattn=[[prompt + " , " + a_prompt] * img_count],
+ text_info=info,
+ )
+ )
+ un_cond = self.model.get_learned_conditioning(
+ dict(c_concat=[hint], c_crossattn=[[n_prompt] * img_count], text_info=info)
+ )
+ shape = (4, h // 8, w // 8)
+ self.model.control_scales = [strength] * 13
+ samples, intermediates = self.ddim_sampler.sample(
+ ddim_steps,
+ img_count,
+ shape,
+ cond,
+ verbose=False,
+ eta=eta,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=un_cond,
+ )
+ if self.use_fp16:
+ samples = samples.half()
+ x_samples = self.model.decode_first_stage(samples)
+ x_samples = (
+ (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
+ .cpu()
+ .numpy()
+ .clip(0, 255)
+ .astype(np.uint8)
+ )
+ results = [x_samples[i] for i in range(img_count)]
+ if (
+ mode == "edit" and False
+ ): # replace backgound in text editing but not ideal yet
+ results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
+ results = [r.clip(0, 255).astype(np.uint8) for r in results]
+ if len(gly_pos_imgs) > 0 and show_debug:
+ glyph_bs = np.stack(gly_pos_imgs, axis=2)
+ glyph_img = np.sum(glyph_bs, axis=2) * 255
+ glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
+ results += [np.repeat(glyph_img, 3, axis=2)]
+ # debug_info
+ if not show_debug:
+ debug_info = ""
+ else:
+ input_prompt = prompt
+ for t in texts:
+ input_prompt = input_prompt.replace("*", f'"{t}"', 1)
+ debug_info = f'Prompt: {input_prompt}
\
+ Size: {w}x{h}
\
+ Image Count: {img_count}
\
+ Seed: {seed}
\
+ Use FP16: {self.use_fp16}
\
+ Cost Time: {(time.time()-tic):.2f}s'
+ rst_code = 1 if str_warning else 0
+ return results, rst_code, str_warning, debug_info
+
+ def init_model(self):
+ font_path = self.font_path
+ self.font = ImageFont.truetype(font_path, size=60)
+ cfg_path = self.cfg_path
+ ckpt_path = os.path.join(self.model_dir, "anytext_v1.1.ckpt")
+ clip_path = os.path.join(self.model_dir, "clip-vit-large-patch14")
+ self.model = create_model(
+ cfg_path,
+ device=self.device,
+ cond_stage_path=clip_path,
+ use_fp16=self.use_fp16,
+ )
+ if self.use_fp16:
+ self.model = self.model.half()
+ self.model.load_state_dict(
+ load_state_dict(ckpt_path, location=self.device), strict=False
+ )
+ self.model.eval()
+ self.model = self.model.to(self.device)
+ self.ddim_sampler = DDIMSampler(self.model, device=self.device)
+
+ def modify_prompt(self, prompt):
+ prompt = prompt.replace("“", '"')
+ prompt = prompt.replace("”", '"')
+ p = '"(.*?)"'
+ strs = re.findall(p, prompt)
+ if len(strs) == 0:
+ strs = [" "]
+ else:
+ for s in strs:
+ prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
+ # if self.is_chinese(prompt):
+ # if self.trans_pipe is None:
+ # return None, None
+ # old_prompt = prompt
+ # prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
+ # print(f"Translate: {old_prompt} --> {prompt}")
+ return prompt, strs
+
+ # def is_chinese(self, text):
+ # text = checker._clean_text(text)
+ # for char in text:
+ # cp = ord(char)
+ # if checker._is_chinese_char(cp):
+ # return True
+ # return False
+
+ def separate_pos_imgs(self, img, sort_priority, gap=102):
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
+ components = []
+ for label in range(1, num_labels):
+ component = np.zeros_like(img)
+ component[labels == label] = 255
+ components.append((component, centroids[label]))
+ if sort_priority == "↕":
+ fir, sec = 1, 0 # top-down first
+ elif sort_priority == "↔":
+ fir, sec = 0, 1 # left-right first
+ components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
+ sorted_components = [c[0] for c in components]
+ return sorted_components
+
+ def find_polygon(self, image, min_rect=False):
+ contours, hierarchy = cv2.findContours(
+ image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+ )
+ max_contour = max(contours, key=cv2.contourArea) # get contour with max area
+ if min_rect:
+ # get minimum enclosing rectangle
+ rect = cv2.minAreaRect(max_contour)
+ poly = np.int0(cv2.boxPoints(rect))
+ else:
+ # get approximate polygon
+ epsilon = 0.01 * cv2.arcLength(max_contour, True)
+ poly = cv2.approxPolyDP(max_contour, epsilon, True)
+ n, _, xy = poly.shape
+ poly = poly.reshape(n, xy)
+ cv2.drawContours(image, [poly], -1, 255, -1)
+ return poly, image
+
+ def arr2tensor(self, arr, bs):
+ arr = np.transpose(arr, (2, 0, 1))
+ _arr = torch.from_numpy(arr.copy()).float().to(self.device)
+ if self.use_fp16:
+ _arr = _arr.half()
+ _arr = torch.stack([_arr for _ in range(bs)], dim=0)
+ return _arr
diff --git a/iopaint/model/anytext/anytext_sd15.yaml b/iopaint/model/anytext/anytext_sd15.yaml
new file mode 100644
index 0000000..a017d90
--- /dev/null
+++ b/iopaint/model/anytext/anytext_sd15.yaml
@@ -0,0 +1,99 @@
+model:
+ target: iopaint.model.anytext.cldm.cldm.ControlLDM
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "img"
+ cond_stage_key: "caption"
+ control_key: "hint"
+ glyph_key: "glyphs"
+ position_key: "positions"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: true # need be true when embedding_manager is valid
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+ only_mid_control: False
+ loss_alpha: 0 # perceptual loss, 0.003
+ loss_beta: 0 # ctc loss
+ latin_weight: 1.0 # latin text line may need smaller weigth
+ with_step_weight: true
+ use_vae_upsample: true
+ embedding_manager_config:
+ target: iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager
+ params:
+ valid: true # v6
+ emb_type: ocr # ocr, vit, conv
+ glyph_channels: 1
+ position_channels: 1
+ add_pos: false
+ placeholder_string: '*'
+
+ control_stage_config:
+ target: iopaint.model.anytext.cldm.cldm.ControlNet
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ model_channels: 320
+ glyph_channels: 1
+ position_channels: 1
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ unet_config:
+ target: iopaint.model.anytext.cldm.cldm.ControlledUnetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
+ params:
+ version: ./models/clip-vit-large-patch14
+ use_vision: false # v6
diff --git a/iopaint/model/anytext/cldm/cldm.py b/iopaint/model/anytext/cldm/cldm.py
new file mode 100644
index 0000000..ad9692a
--- /dev/null
+++ b/iopaint/model/anytext/cldm/cldm.py
@@ -0,0 +1,630 @@
+import os
+from pathlib import Path
+
+import einops
+import torch
+import torch as th
+import torch.nn as nn
+import copy
+from easydict import EasyDict as edict
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+ conv_nd,
+ linear,
+ zero_module,
+ timestep_embedding,
+)
+
+from einops import rearrange, repeat
+from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
+from iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
+from iopaint.model.anytext.ldm.util import log_txt_as_img, exists, instantiate_from_config
+from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
+from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from .recognizer import TextRecognizer, create_predictor
+
+CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+class ControlledUnetModel(UNetModel):
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
+ hs = []
+ with torch.no_grad():
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ if self.use_fp16:
+ t_emb = t_emb.half()
+ emb = self.time_embed(t_emb)
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+
+ if control is not None:
+ h += control.pop()
+
+ for i, module in enumerate(self.output_blocks):
+ if only_mid_control or control is None:
+ h = torch.cat([h, hs.pop()], dim=1)
+ else:
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+ h = module(h, emb, context)
+
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+class ControlNet(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ glyph_channels,
+ position_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+ self.dims = dims
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.use_fp16 = use_fp16
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+ self.glyph_block = TimestepEmbedSequential(
+ conv_nd(dims, glyph_channels, 8, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 8, 8, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 96, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ )
+
+ self.position_block = TimestepEmbedSequential(
+ conv_nd(dims, position_channels, 8, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 8, 8, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 64, 3, padding=1, stride=2),
+ nn.SiLU(),
+ )
+
+ self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
+
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self.zero_convs.append(self.make_zero_conv(ch))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ self.zero_convs.append(self.make_zero_conv(ch))
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self.middle_block_out = self.make_zero_conv(ch)
+ self._feature_size += ch
+
+ def make_zero_conv(self, channels):
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+ def forward(self, x, hint, text_info, timesteps, context, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ if self.use_fp16:
+ t_emb = t_emb.half()
+ emb = self.time_embed(t_emb)
+
+ # guided_hint from text_info
+ B, C, H, W = x.shape
+ glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
+ positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
+ enc_glyph = self.glyph_block(glyphs, emb, context)
+ enc_pos = self.position_block(positions, emb, context)
+ guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
+
+ outs = []
+
+ h = x.type(self.dtype)
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ if guided_hint is not None:
+ h = module(h, emb, context)
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+ outs.append(zero_conv(h, emb, context))
+
+ h = self.middle_block(h, emb, context)
+ outs.append(self.middle_block_out(h, emb, context))
+
+ return outs
+
+
+class ControlLDM(LatentDiffusion):
+
+ def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
+ self.use_fp16 = kwargs.pop('use_fp16', False)
+ super().__init__(*args, **kwargs)
+ self.control_model = instantiate_from_config(control_stage_config)
+ self.control_key = control_key
+ self.glyph_key = glyph_key
+ self.position_key = position_key
+ self.only_mid_control = only_mid_control
+ self.control_scales = [1.0] * 13
+ self.loss_alpha = loss_alpha
+ self.loss_beta = loss_beta
+ self.with_step_weight = with_step_weight
+ self.use_vae_upsample = use_vae_upsample
+ self.latin_weight = latin_weight
+
+ if embedding_manager_config is not None and embedding_manager_config.params.valid:
+ self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
+ for param in self.embedding_manager.embedding_parameters():
+ param.requires_grad = True
+ else:
+ self.embedding_manager = None
+ if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
+ if embedding_manager_config.params.emb_type == 'ocr':
+ self.text_predictor = create_predictor().eval()
+ args = edict()
+ args.rec_image_shape = "3, 48, 320"
+ args.rec_batch_num = 6
+ args.rec_char_dict_path = str(CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt")
+ args.use_fp16 = self.use_fp16
+ self.cn_recognizer = TextRecognizer(args, self.text_predictor)
+ for param in self.text_predictor.parameters():
+ param.requires_grad = False
+ if self.embedding_manager:
+ self.embedding_manager.recog = self.cn_recognizer
+
+ @torch.no_grad()
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
+ if self.embedding_manager is None: # fill in full caption
+ self.fill_caption(batch)
+ x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
+ control = batch[self.control_key] # for log_images and loss_alpha, not real control
+ if bs is not None:
+ control = control[:bs]
+ control = control.to(self.device)
+ control = einops.rearrange(control, 'b h w c -> b c h w')
+ control = control.to(memory_format=torch.contiguous_format).float()
+
+ inv_mask = batch['inv_mask']
+ if bs is not None:
+ inv_mask = inv_mask[:bs]
+ inv_mask = inv_mask.to(self.device)
+ inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
+ inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
+
+ glyphs = batch[self.glyph_key]
+ gly_line = batch['gly_line']
+ positions = batch[self.position_key]
+ n_lines = batch['n_lines']
+ language = batch['language']
+ texts = batch['texts']
+ assert len(glyphs) == len(positions)
+ for i in range(len(glyphs)):
+ if bs is not None:
+ glyphs[i] = glyphs[i][:bs]
+ gly_line[i] = gly_line[i][:bs]
+ positions[i] = positions[i][:bs]
+ n_lines = n_lines[:bs]
+ glyphs[i] = glyphs[i].to(self.device)
+ gly_line[i] = gly_line[i].to(self.device)
+ positions[i] = positions[i].to(self.device)
+ glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
+ gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
+ positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
+ glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
+ gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
+ positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
+ info = {}
+ info['glyphs'] = glyphs
+ info['positions'] = positions
+ info['n_lines'] = n_lines
+ info['language'] = language
+ info['texts'] = texts
+ info['img'] = batch['img'] # nhwc, (-1,1)
+ info['masked_x'] = mx
+ info['gly_line'] = gly_line
+ info['inv_mask'] = inv_mask
+ return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
+
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
+ assert isinstance(cond, dict)
+ diffusion_model = self.model.diffusion_model
+ _cond = torch.cat(cond['c_crossattn'], 1)
+ _hint = torch.cat(cond['c_concat'], 1)
+ if self.use_fp16:
+ x_noisy = x_noisy.half()
+ control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
+
+ return eps
+
+ def instantiate_embedding_manager(self, config, embedder):
+ model = instantiate_from_config(config, embedder=embedder)
+ return model
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, N):
+ return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ if self.embedding_manager is not None and c['text_info'] is not None:
+ self.embedding_manager.encode_text(c['text_info'])
+ if isinstance(c, dict):
+ cond_txt = c['c_crossattn'][0]
+ else:
+ cond_txt = c
+ if self.embedding_manager is not None:
+ cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
+ else:
+ cond_txt = self.cond_stage_model.encode(cond_txt)
+ if isinstance(c, dict):
+ c['c_crossattn'][0] = cond_txt
+ else:
+ c = cond_txt
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def fill_caption(self, batch, place_holder='*'):
+ bs = len(batch['n_lines'])
+ cond_list = copy.deepcopy(batch[self.cond_stage_key])
+ for i in range(bs):
+ n_lines = batch['n_lines'][i]
+ if n_lines == 0:
+ continue
+ cur_cap = cond_list[i]
+ for j in range(n_lines):
+ r_txt = batch['texts'][j][i]
+ cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
+ cond_list[i] = cur_cap
+ batch[self.cond_stage_key] = cond_list
+
+ @torch.no_grad()
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
+ if self.cond_stage_trainable:
+ with torch.no_grad():
+ c = self.get_learned_conditioning(c)
+ c_crossattn = c["c_crossattn"][0][:N]
+ c_cat = c["c_concat"][0][:N]
+ text_info = c["text_info"]
+ text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
+ text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
+ text_info['positions'] = [i[:N] for i in text_info['positions']]
+ text_info['n_lines'] = text_info['n_lines'][:N]
+ text_info['masked_x'] = text_info['masked_x'][:N]
+ text_info['img'] = text_info['img'][:N]
+
+ N = min(z.shape[0], N)
+ n_row = min(z.shape[0], n_row)
+ log["reconstruction"] = self.decode_first_stage(z)
+ log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
+ log["control"] = c_cat * 2.0 - 1.0
+ log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed
+ # get glyph
+ glyph_bs = torch.stack(text_info['glyphs'])
+ glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
+ log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
+ # fill caption
+ if not self.embedding_manager:
+ self.fill_caption(batch)
+ captions = batch[self.cond_stage_key]
+ log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N)
+ uc_cat = c_cat # torch.zeros_like(c_cat)
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
+ samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+ pred_x0 = False # wether log pred_x0
+ if pred_x0:
+ for idx in range(len(tmps['pred_x0'])):
+ pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
+ log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
+
+ return log
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ ddim_sampler = DDIMSampler(self)
+ b, c, h, w = cond["c_concat"][0].shape
+ shape = (self.channels, h // 8, w // 8)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
+ return samples, intermediates
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.control_model.parameters())
+ if self.embedding_manager:
+ params += list(self.embedding_manager.embedding_parameters())
+ if not self.sd_locked:
+ # params += list(self.model.diffusion_model.input_blocks.parameters())
+ # params += list(self.model.diffusion_model.middle_block.parameters())
+ params += list(self.model.diffusion_model.output_blocks.parameters())
+ params += list(self.model.diffusion_model.out.parameters())
+ if self.unlockKV:
+ nCount = 0
+ for name, param in self.model.diffusion_model.named_parameters():
+ if 'attn2.to_k' in name or 'attn2.to_v' in name:
+ params += [param]
+ nCount += 1
+ print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to potimizers!!!')
+
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+ def low_vram_shift(self, is_diffusing):
+ if is_diffusing:
+ self.model = self.model.cuda()
+ self.control_model = self.control_model.cuda()
+ self.first_stage_model = self.first_stage_model.cpu()
+ self.cond_stage_model = self.cond_stage_model.cpu()
+ else:
+ self.model = self.model.cpu()
+ self.control_model = self.control_model.cpu()
+ self.first_stage_model = self.first_stage_model.cuda()
+ self.cond_stage_model = self.cond_stage_model.cuda()
diff --git a/iopaint/model/anytext/cldm/ddim_hacked.py b/iopaint/model/anytext/cldm/ddim_hacked.py
new file mode 100644
index 0000000..87ea63b
--- /dev/null
+++ b/iopaint/model/anytext/cldm/ddim_hacked.py
@@ -0,0 +1,486 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+ make_ddim_sampling_parameters,
+ make_ddim_timesteps,
+ noise_like,
+ extract_into_tensor,
+)
+
+
+class DDIMSampler(object):
+ def __init__(self, model, device, schedule="linear", **kwargs):
+ super().__init__()
+ self.device = device
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device(self.device):
+ attr = attr.to(torch.device(self.device))
+ setattr(self, name, attr)
+
+ def make_schedule(
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+ ):
+ self.ddim_timesteps = make_ddim_timesteps(
+ ddim_discr_method=ddim_discretize,
+ num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
+ verbose=verbose,
+ )
+ alphas_cumprod = self.model.alphas_cumprod
+ assert (
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+ ), "alphas have to be defined for each timestep"
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
+
+ self.register_buffer("betas", to_torch(self.model.betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer(
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+ )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer(
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+ )
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,
+ verbose=verbose,
+ )
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
+ self.register_buffer("ddim_alphas", ddim_alphas)
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev)
+ / (1 - self.alphas_cumprod)
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+ )
+ self.register_buffer(
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+ )
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.0,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs,
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+ )
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f"Data shape for DDIM sampling is {size}, eta {eta}")
+
+ samples, intermediates = self.ddim_sampling(
+ conditioning,
+ size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask,
+ x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(
+ self,
+ cond,
+ shape,
+ x_T=None,
+ ddim_use_original_steps=False,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ log_every_t=100,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ ):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = (
+ self.ddpm_num_timesteps
+ if ddim_use_original_steps
+ else self.ddim_timesteps
+ )
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = (
+ int(
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
+ * self.ddim_timesteps.shape[0]
+ )
+ - 1
+ )
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
+ time_range = (
+ reversed(range(0, timesteps))
+ if ddim_use_original_steps
+ else np.flip(timesteps)
+ )
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(
+ x0, ts
+ ) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(
+ img,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ img, pred_x0 = outs
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates["x_inter"].append(img)
+ intermediates["pred_x0"].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(
+ self,
+ x,
+ c,
+ t,
+ index,
+ repeat_noise=False,
+ use_original_steps=False,
+ quantize_denoised=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ dynamic_threshold=None,
+ ):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ model_t = self.model.apply_model(x, t, c)
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+ model_output = model_uncond + unconditional_guidance_scale * (
+ model_t - model_uncond
+ )
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", "not implemented"
+ e_t = score_corrector.modify_score(
+ self.model, e_t, x, t, c, **corrector_kwargs
+ )
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = (
+ self.model.alphas_cumprod_prev
+ if use_original_steps
+ else self.ddim_alphas_prev
+ )
+ sqrt_one_minus_alphas = (
+ self.model.sqrt_one_minus_alphas_cumprod
+ if use_original_steps
+ else self.ddim_sqrt_one_minus_alphas
+ )
+ sigmas = (
+ self.model.ddim_sigmas_for_original_num_steps
+ if use_original_steps
+ else self.ddim_sigmas
+ )
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full(
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+ )
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(
+ self,
+ x0,
+ c,
+ t_enc,
+ use_original_steps=False,
+ return_intermediates=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ callback=None,
+ ):
+ timesteps = (
+ np.arange(self.ddpm_num_timesteps)
+ if use_original_steps
+ else self.ddim_timesteps
+ )
+ num_reference_steps = timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc="Encoding Image"):
+ t = torch.full(
+ (x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
+ )
+ if unconditional_guidance_scale == 1.0:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(
+ torch.cat((x_next, x_next)),
+ torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c)),
+ ),
+ 2,
+ )
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (
+ noise_pred - e_t_uncond
+ )
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = (
+ alphas_next[i].sqrt()
+ * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
+ * noise_pred
+ )
+ x_next = xt_weighted + weighted_noise_pred
+ if (
+ return_intermediates
+ and i % (num_steps // return_intermediates) == 0
+ and i < num_steps - 1
+ ):
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback:
+ callback(i)
+
+ out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
+ if return_intermediates:
+ out.update({"intermediates": intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+ )
+
+ @torch.no_grad()
+ def decode(
+ self,
+ x_latent,
+ cond,
+ t_start,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ use_original_steps=False,
+ callback=None,
+ ):
+ timesteps = (
+ np.arange(self.ddpm_num_timesteps)
+ if use_original_steps
+ else self.ddim_timesteps
+ )
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full(
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+ )
+ x_dec, _ = self.p_sample_ddim(
+ x_dec,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ if callback:
+ callback(i)
+ return x_dec
diff --git a/iopaint/model/anytext/cldm/embedding_manager.py b/iopaint/model/anytext/cldm/embedding_manager.py
new file mode 100644
index 0000000..6ccf8a9
--- /dev/null
+++ b/iopaint/model/anytext/cldm/embedding_manager.py
@@ -0,0 +1,165 @@
+'''
+Copyright (c) Alibaba, Inc. and its affiliates.
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear
+
+
+def get_clip_token_for_string(tokenizer, string):
+ batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"]
+ assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
+ return tokens[0, 1]
+
+
+def get_bert_token_for_string(tokenizer, string):
+ token = tokenizer(string)
+ assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
+ token = token[0, 1]
+ return token
+
+
+def get_clip_vision_emb(encoder, processor, img):
+ _img = img.repeat(1, 3, 1, 1)*255
+ inputs = processor(images=_img, return_tensors="pt")
+ inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
+ outputs = encoder(**inputs)
+ emb = outputs.image_embeds
+ return emb
+
+
+def get_recog_emb(encoder, img_list):
+ _img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
+ encoder.predictor.eval()
+ _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
+ return preds_neck
+
+
+def pad_H(x):
+ _, _, H, W = x.shape
+ p_top = (W - H) // 2
+ p_bot = W - H - p_top
+ return F.pad(x, (0, 0, p_top, p_bot))
+
+
+class EncodeNet(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(EncodeNet, self).__init__()
+ chan = 16
+ n_layer = 4 # downsample
+
+ self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
+ self.conv_list = nn.ModuleList([])
+ _c = chan
+ for i in range(n_layer):
+ self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
+ _c *= 2
+ self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ self.act = nn.SiLU()
+
+ def forward(self, x):
+ x = self.act(self.conv1(x))
+ for layer in self.conv_list:
+ x = self.act(layer(x))
+ x = self.act(self.conv2(x))
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ return x
+
+
+class EmbeddingManager(nn.Module):
+ def __init__(
+ self,
+ embedder,
+ valid=True,
+ glyph_channels=20,
+ position_channels=1,
+ placeholder_string='*',
+ add_pos=False,
+ emb_type='ocr',
+ **kwargs
+ ):
+ super().__init__()
+ if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
+ get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
+ token_dim = 768
+ if hasattr(embedder, 'vit'):
+ assert emb_type == 'vit'
+ self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
+ self.get_recog_emb = None
+ else: # using LDM's BERT encoder
+ get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
+ token_dim = 1280
+ self.token_dim = token_dim
+ self.emb_type = emb_type
+
+ self.add_pos = add_pos
+ if add_pos:
+ self.position_encoder = EncodeNet(position_channels, token_dim)
+ if emb_type == 'ocr':
+ self.proj = linear(40*64, token_dim)
+ if emb_type == 'conv':
+ self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
+
+ self.placeholder_token = get_token_for_string(placeholder_string)
+
+ def encode_text(self, text_info):
+ if self.get_recog_emb is None and self.emb_type == 'ocr':
+ self.get_recog_emb = partial(get_recog_emb, self.recog)
+
+ gline_list = []
+ pos_list = []
+ for i in range(len(text_info['n_lines'])): # sample index in a batch
+ n_lines = text_info['n_lines'][i]
+ for j in range(n_lines): # line
+ gline_list += [text_info['gly_line'][j][i:i+1]]
+ if self.add_pos:
+ pos_list += [text_info['positions'][j][i:i+1]]
+
+ if len(gline_list) > 0:
+ if self.emb_type == 'ocr':
+ recog_emb = self.get_recog_emb(gline_list)
+ enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
+ elif self.emb_type == 'vit':
+ enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
+ elif self.emb_type == 'conv':
+ enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
+ if self.add_pos:
+ enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
+ enc_glyph = enc_glyph+enc_pos
+
+ self.text_embs_all = []
+ n_idx = 0
+ for i in range(len(text_info['n_lines'])): # sample index in a batch
+ n_lines = text_info['n_lines'][i]
+ text_embs = []
+ for j in range(n_lines): # line
+ text_embs += [enc_glyph[n_idx:n_idx+1]]
+ n_idx += 1
+ self.text_embs_all += [text_embs]
+
+ def forward(
+ self,
+ tokenized_text,
+ embedded_text,
+ ):
+ b, device = tokenized_text.shape[0], tokenized_text.device
+ for i in range(b):
+ idx = tokenized_text[i] == self.placeholder_token.to(device)
+ if sum(idx) > 0:
+ if i >= len(self.text_embs_all):
+ print('truncation for log images...')
+ break
+ text_emb = torch.cat(self.text_embs_all[i], dim=0)
+ if sum(idx) != len(text_emb):
+ print('truncation for long caption...')
+ embedded_text[i][idx] = text_emb[:sum(idx)]
+ return embedded_text
+
+ def embedding_parameters(self):
+ return self.parameters()
diff --git a/iopaint/model/anytext/cldm/hack.py b/iopaint/model/anytext/cldm/hack.py
new file mode 100644
index 0000000..05afe5f
--- /dev/null
+++ b/iopaint/model/anytext/cldm/hack.py
@@ -0,0 +1,111 @@
+import torch
+import einops
+
+import iopaint.model.anytext.ldm.modules.encoders.modules
+import iopaint.model.anytext.ldm.modules.attention
+
+from transformers import logging
+from iopaint.model.anytext.ldm.modules.attention import default
+
+
+def disable_verbosity():
+ logging.set_verbosity_error()
+ print('logging improved.')
+ return
+
+
+def enable_sliced_attention():
+ iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
+ print('Enabled sliced_attention.')
+ return
+
+
+def hack_everything(clip_skip=0):
+ disable_verbosity()
+ iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+ iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+ print('Enabled clip hacks.')
+ return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+ PAD = self.tokenizer.pad_token_id
+ EOS = self.tokenizer.eos_token_id
+ BOS = self.tokenizer.bos_token_id
+
+ def tokenize(t):
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+ def transformer_encode(t):
+ if self.clip_skip > 1:
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+ else:
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+ def split(x):
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+ def pad(x, p, i):
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+ raw_tokens_list = tokenize(text)
+ tokens_list = []
+
+ for raw_tokens in raw_tokens_list:
+ raw_tokens_123 = split(raw_tokens)
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+ tokens_list.append(raw_tokens_123)
+
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+ y = transformer_encode(feed)
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+ return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ limit = k.shape[0]
+ att_step = 1
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+ q_chunks.reverse()
+ k_chunks.reverse()
+ v_chunks.reverse()
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ del k, q, v
+ for i in range(0, limit, att_step):
+ q_buffer = q_chunks.pop()
+ k_buffer = k_chunks.pop()
+ v_buffer = v_chunks.pop()
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+ del k_buffer, q_buffer
+ # attention, what we cannot get enough of, by chunks
+
+ sim_buffer = sim_buffer.softmax(dim=-1)
+
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+ del v_buffer
+ sim[i:i + att_step, :, :] = sim_buffer
+
+ del sim_buffer
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(sim)
diff --git a/iopaint/model/anytext/cldm/model.py b/iopaint/model/anytext/cldm/model.py
new file mode 100644
index 0000000..6d2d2c3
--- /dev/null
+++ b/iopaint/model/anytext/cldm/model.py
@@ -0,0 +1,40 @@
+import os
+import torch
+
+from omegaconf import OmegaConf
+from iopaint.model.anytext.ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+ return d.get("state_dict", d)
+
+
+def load_state_dict(ckpt_path, location="cpu"):
+ _, extension = os.path.splitext(ckpt_path)
+ if extension.lower() == ".safetensors":
+ import safetensors.torch
+
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+ else:
+ state_dict = get_state_dict(
+ torch.load(ckpt_path, map_location=torch.device(location))
+ )
+ state_dict = get_state_dict(state_dict)
+ print(f"Loaded state_dict from [{ckpt_path}]")
+ return state_dict
+
+
+def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
+ config = OmegaConf.load(config_path)
+ if cond_stage_path:
+ config.model.params.cond_stage_config.params.version = (
+ cond_stage_path # use pre-downloaded ckpts, in case blocked
+ )
+ config.model.params.cond_stage_config.params.device = device
+ if use_fp16:
+ config.model.params.use_fp16 = True
+ config.model.params.control_stage_config.params.use_fp16 = True
+ config.model.params.unet_config.params.use_fp16 = True
+ model = instantiate_from_config(config.model).cpu()
+ print(f"Loaded model config from [{config_path}]")
+ return model
diff --git a/iopaint/model/anytext/cldm/recognizer.py b/iopaint/model/anytext/cldm/recognizer.py
new file mode 100755
index 0000000..0621512
--- /dev/null
+++ b/iopaint/model/anytext/cldm/recognizer.py
@@ -0,0 +1,300 @@
+"""
+Copyright (c) Alibaba, Inc. and its affiliates.
+"""
+import os
+import cv2
+import numpy as np
+import math
+import traceback
+from easydict import EasyDict as edict
+import time
+from iopaint.model.anytext.ocr_recog.RecModel import RecModel
+import torch
+import torch.nn.functional as F
+
+
+def min_bounding_rect(img):
+ ret, thresh = cv2.threshold(img, 127, 255, 0)
+ contours, hierarchy = cv2.findContours(
+ thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
+ if len(contours) == 0:
+ print("Bad contours, using fake bbox...")
+ return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
+ max_contour = max(contours, key=cv2.contourArea)
+ rect = cv2.minAreaRect(max_contour)
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ # sort
+ x_sorted = sorted(box, key=lambda x: x[0])
+ left = x_sorted[:2]
+ right = x_sorted[2:]
+ left = sorted(left, key=lambda x: x[1])
+ (tl, bl) = left
+ right = sorted(right, key=lambda x: x[1])
+ (tr, br) = right
+ if tl[1] > bl[1]:
+ (tl, bl) = (bl, tl)
+ if tr[1] > br[1]:
+ (tr, br) = (br, tr)
+ return np.array([tl, tr, br, bl])
+
+
+def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
+ model_file_path = model_dir
+ if model_file_path is not None and not os.path.exists(model_file_path):
+ raise ValueError("not find model file path {}".format(model_file_path))
+
+ if is_onnx:
+ import onnxruntime as ort
+
+ sess = ort.InferenceSession(
+ model_file_path, providers=["CPUExecutionProvider"]
+ ) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
+ return sess
+ else:
+ if model_lang == "ch":
+ n_class = 6625
+ elif model_lang == "en":
+ n_class = 97
+ else:
+ raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
+ rec_config = edict(
+ in_channels=3,
+ backbone=edict(
+ type="MobileNetV1Enhance",
+ scale=0.5,
+ last_conv_stride=[1, 2],
+ last_pool_type="avg",
+ ),
+ neck=edict(
+ type="SequenceEncoder",
+ encoder_type="svtr",
+ dims=64,
+ depth=2,
+ hidden_dims=120,
+ use_guide=True,
+ ),
+ head=edict(
+ type="CTCHead",
+ fc_decay=0.00001,
+ out_channels=n_class,
+ return_feats=True,
+ ),
+ )
+
+ rec_model = RecModel(rec_config)
+ if model_file_path is not None:
+ rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
+ rec_model.eval()
+ return rec_model.eval()
+
+
+def _check_image_file(path):
+ img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
+ return any([path.lower().endswith(e) for e in img_end])
+
+
+def get_image_file_list(img_file):
+ imgs_lists = []
+ if img_file is None or not os.path.exists(img_file):
+ raise Exception("not found any img file in {}".format(img_file))
+ if os.path.isfile(img_file) and _check_image_file(img_file):
+ imgs_lists.append(img_file)
+ elif os.path.isdir(img_file):
+ for single_file in os.listdir(img_file):
+ file_path = os.path.join(img_file, single_file)
+ if os.path.isfile(file_path) and _check_image_file(file_path):
+ imgs_lists.append(file_path)
+ if len(imgs_lists) == 0:
+ raise Exception("not found any img file in {}".format(img_file))
+ imgs_lists = sorted(imgs_lists)
+ return imgs_lists
+
+
+class TextRecognizer(object):
+ def __init__(self, args, predictor):
+ self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
+ self.rec_batch_num = args.rec_batch_num
+ self.predictor = predictor
+ self.chars = self.get_char_dict(args.rec_char_dict_path)
+ self.char2id = {x: i for i, x in enumerate(self.chars)}
+ self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
+ self.use_fp16 = args.use_fp16
+
+ # img: CHW
+ def resize_norm_img(self, img, max_wh_ratio):
+ imgC, imgH, imgW = self.rec_image_shape
+ assert imgC == img.shape[0]
+ imgW = int((imgH * max_wh_ratio))
+
+ h, w = img.shape[1:]
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = torch.nn.functional.interpolate(
+ img.unsqueeze(0),
+ size=(imgH, resized_w),
+ mode="bilinear",
+ align_corners=True,
+ )
+ resized_image /= 255.0
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
+ padding_im[:, :, 0:resized_w] = resized_image[0]
+ return padding_im
+
+ # img_list: list of tensors with shape chw 0-255
+ def pred_imglist(self, img_list, show_debug=False, is_ori=False):
+ img_num = len(img_list)
+ assert img_num > 0
+ # Calculate the aspect ratio of all text bars
+ width_list = []
+ for img in img_list:
+ width_list.append(img.shape[2] / float(img.shape[1]))
+ # Sorting can speed up the recognition process
+ indices = torch.from_numpy(np.argsort(np.array(width_list)))
+ batch_num = self.rec_batch_num
+ preds_all = [None] * img_num
+ preds_neck_all = [None] * img_num
+ for beg_img_no in range(0, img_num, batch_num):
+ end_img_no = min(img_num, beg_img_no + batch_num)
+ norm_img_batch = []
+
+ imgC, imgH, imgW = self.rec_image_shape[:3]
+ max_wh_ratio = imgW / imgH
+ for ino in range(beg_img_no, end_img_no):
+ h, w = img_list[indices[ino]].shape[1:]
+ if h > w * 1.2:
+ img = img_list[indices[ino]]
+ img = torch.transpose(img, 1, 2).flip(dims=[1])
+ img_list[indices[ino]] = img
+ h, w = img.shape[1:]
+ # wh_ratio = w * 1.0 / h
+ # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
+ for ino in range(beg_img_no, end_img_no):
+ norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+ if self.use_fp16:
+ norm_img = norm_img.half()
+ norm_img = norm_img.unsqueeze(0)
+ norm_img_batch.append(norm_img)
+ norm_img_batch = torch.cat(norm_img_batch, dim=0)
+ if show_debug:
+ for i in range(len(norm_img_batch)):
+ _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
+ _img = (_img + 0.5) * 255
+ _img = _img[:, :, ::-1]
+ file_name = f"{indices[beg_img_no + i]}"
+ file_name = file_name + "_ori" if is_ori else file_name
+ cv2.imwrite(file_name + ".jpg", _img)
+ if self.is_onnx:
+ input_dict = {}
+ input_dict[self.predictor.get_inputs()[0].name] = (
+ norm_img_batch.detach().cpu().numpy()
+ )
+ outputs = self.predictor.run(None, input_dict)
+ preds = {}
+ preds["ctc"] = torch.from_numpy(outputs[0])
+ preds["ctc_neck"] = [torch.zeros(1)] * img_num
+ else:
+ preds = self.predictor(norm_img_batch)
+ for rno in range(preds["ctc"].shape[0]):
+ preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
+ preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
+
+ return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
+
+ def get_char_dict(self, character_dict_path):
+ character_str = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
+ character_str.append(line)
+ dict_character = list(character_str)
+ dict_character = ["sos"] + dict_character + [" "] # eos is space
+ return dict_character
+
+ def get_text(self, order):
+ char_list = [self.chars[text_id] for text_id in order]
+ return "".join(char_list)
+
+ def decode(self, mat):
+ text_index = mat.detach().cpu().numpy().argmax(axis=1)
+ ignored_tokens = [0]
+ selection = np.ones(len(text_index), dtype=bool)
+ selection[1:] = text_index[1:] != text_index[:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index != ignored_token
+ return text_index[selection], np.where(selection)[0]
+
+ def get_ctcloss(self, preds, gt_text, weight):
+ if not isinstance(weight, torch.Tensor):
+ weight = torch.tensor(weight).to(preds.device)
+ ctc_loss = torch.nn.CTCLoss(reduction="none")
+ log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
+ targets = []
+ target_lengths = []
+ for t in gt_text:
+ targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
+ target_lengths += [len(t)]
+ targets = torch.tensor(targets).to(preds.device)
+ target_lengths = torch.tensor(target_lengths).to(preds.device)
+ input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(
+ preds.device
+ )
+ loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
+ loss = loss / input_lengths * weight
+ return loss
+
+
+def main():
+ rec_model_dir = "./ocr_weights/ppv3_rec.pth"
+ predictor = create_predictor(rec_model_dir)
+ args = edict()
+ args.rec_image_shape = "3, 48, 320"
+ args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
+ args.rec_batch_num = 6
+ text_recognizer = TextRecognizer(args, predictor)
+ image_dir = "./test_imgs_cn"
+ gt_text = ["韩国小馆"] * 14
+
+ image_file_list = get_image_file_list(image_dir)
+ valid_image_file_list = []
+ img_list = []
+
+ for image_file in image_file_list:
+ img = cv2.imread(image_file)
+ if img is None:
+ print("error in loading image:{}".format(image_file))
+ continue
+ valid_image_file_list.append(image_file)
+ img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
+ try:
+ tic = time.time()
+ times = []
+ for i in range(10):
+ preds, _ = text_recognizer.pred_imglist(img_list) # get text
+ preds_all = preds.softmax(dim=2)
+ times += [(time.time() - tic) * 1000.0]
+ tic = time.time()
+ print(times)
+ print(np.mean(times[1:]) / len(preds_all))
+ weight = np.ones(len(gt_text))
+ loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
+ for i in range(len(valid_image_file_list)):
+ pred = preds_all[i]
+ order, idx = text_recognizer.decode(pred)
+ text = text_recognizer.get_text(order)
+ print(
+ f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
+ )
+ except Exception as E:
+ print(traceback.format_exc(), E)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/iopaint/model/anytext/ldm/models/autoencoder.py b/iopaint/model/anytext/ldm/models/autoencoder.py
new file mode 100644
index 0000000..20d52e9
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/autoencoder.py
@@ -0,0 +1,218 @@
+import torch
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.model import Encoder, Decoder
+from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from iopaint.model.anytext.ldm.util import instantiate_from_config
+from iopaint.model.anytext.ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(torch.nn.Module):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ema_decay=None,
+ learn_logvar=False
+ ):
+ super().__init__()
+ self.learn_logvar = learn_logvar
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.use_ema = ema_decay is not None
+ if self.use_ema:
+ self.ema_decay = ema_decay
+ assert 0. < ema_decay < 1.
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix=""):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+ if self.learn_logvar:
+ print(f"{self.__class__.__name__}: Learning logvar")
+ ae_params_list.append(self.loss.logvar)
+ opt_ae = torch.optim.Adam(ae_params_list,
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ if log_ema or self.use_ema:
+ with self.ema_scope():
+ xrec_ema, posterior_ema = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec_ema.shape[1] > 3
+ xrec_ema = self.to_rgb(xrec_ema)
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+ log["reconstructions_ema"] = xrec_ema
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
diff --git a/iopaint/model/anytext/ldm/models/diffusion/__init__.py b/iopaint/model/anytext/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/iopaint/model/anytext/ldm/models/diffusion/ddim.py b/iopaint/model/anytext/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000..f8bbaff
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/diffusion/ddim.py
@@ -0,0 +1,354 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ # cbs = len(ctmp[0])
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img], "index": [10000]}
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+ intermediates['index'].append(index)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ elif isinstance(c[k], dict):
+ c_in[k] = dict()
+ for key in c[k]:
+ if isinstance(c[k][key], list):
+ if not isinstance(c[k][key][0], torch.Tensor):
+ continue
+ c_in[k][key] = [torch.cat([
+ unconditional_conditioning[k][key][i],
+ c[k][key][i]]) for i in range(len(c[k][key]))]
+ else:
+ c_in[k][key] = torch.cat([
+ unconditional_conditioning[k][key],
+ c[k][key]])
+
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ elif isinstance(c, list):
+ c_in = list()
+ assert isinstance(unconditional_conditioning, list)
+ for i in range(len(c)):
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
\ No newline at end of file
diff --git a/iopaint/model/anytext/ldm/models/diffusion/ddpm.py b/iopaint/model/anytext/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000..9f48918
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,2380 @@
+"""
+Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from omegaconf import ListConfig
+
+from iopaint.model.anytext.ldm.util import (
+ log_txt_as_img,
+ exists,
+ default,
+ ismap,
+ isimage,
+ mean_flat,
+ count_params,
+ instantiate_from_config,
+)
+from iopaint.model.anytext.ldm.modules.ema import LitEma
+from iopaint.model.anytext.ldm.modules.distributions.distributions import (
+ normal_kl,
+ DiagonalGaussianDistribution,
+)
+from iopaint.model.anytext.ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+ make_beta_schedule,
+ extract_into_tensor,
+ noise_like,
+)
+from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
+import cv2
+
+
+__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
+
+PRINT_DEBUG = False
+
+
+def print_grad(grad):
+ # print('Gradient:', grad)
+ # print(grad.shape)
+ a = grad.max()
+ b = grad.min()
+ # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}')
+ s = 255.0 / (a - b)
+ c = 255 * (-b / (a - b))
+ grad = grad * s + c
+ # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}')
+ img = grad[0].permute(1, 2, 0).detach().cpu().numpy()
+ if img.shape[0] == 512:
+ cv2.imwrite("grad-img.jpg", img)
+ elif img.shape[0] == 64:
+ cv2.imwrite("grad-latent.jpg", img)
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(torch.nn.Module):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(
+ self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.0,
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.0,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.0,
+ make_it_fit=False,
+ ucg_training=None,
+ reset_ema=False,
+ reset_num_ema_updates=False,
+ ):
+ super().__init__()
+ assert parameterization in [
+ "eps",
+ "x0",
+ "v",
+ ], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(
+ f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
+ )
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ self.make_it_fit = make_it_fit
+ if reset_ema:
+ assert exists(ckpt_path)
+ if ckpt_path is not None:
+ self.init_from_ckpt(
+ ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
+ )
+ if reset_ema:
+ assert self.use_ema
+ print(
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
+ )
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(
+ " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
+ )
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ self.register_schedule(
+ given_betas=given_betas,
+ beta_schedule=beta_schedule,
+ timesteps=timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+ else:
+ self.register_buffer("logvar", logvar)
+
+ self.ucg_training = ucg_training or dict()
+ if self.ucg_training:
+ self.ucg_prng = np.random.RandomState()
+
+ def register_schedule(
+ self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(
+ beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ # np.save('1.npy', alphas_cumprod)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert (
+ alphas_cumprod.shape[0] == self.num_timesteps
+ ), "alphas have to be defined for each timestep"
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+ )
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (
+ 1.0 - alphas_cumprod_prev
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer(
+ "posterior_log_variance_clipped",
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+ )
+ self.register_buffer(
+ "posterior_mean_coef1",
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+ )
+ self.register_buffer(
+ "posterior_mean_coef2",
+ to_torch(
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+ ),
+ )
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas**2 / (
+ 2
+ * self.posterior_variance
+ * to_torch(alphas)
+ * (1 - self.alphas_cumprod)
+ )
+ elif self.parameterization == "x0":
+ lvlb_weights = (
+ 0.5
+ * np.sqrt(torch.Tensor(alphas_cumprod))
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+ )
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(
+ self.betas**2
+ / (
+ 2
+ * self.posterior_variance
+ * to_torch(alphas)
+ * (1 - self.alphas_cumprod)
+ )
+ )
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @torch.no_grad()
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ if self.make_it_fit:
+ n_params = len(
+ [
+ name
+ for name, _ in itertools.chain(
+ self.named_parameters(), self.named_buffers()
+ )
+ ]
+ )
+ for name, param in tqdm(
+ itertools.chain(self.named_parameters(), self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params,
+ ):
+ if not name in sd:
+ continue
+ old_shape = sd[name].shape
+ new_shape = param.shape
+ assert len(old_shape) == len(new_shape)
+ if len(new_shape) > 2:
+ # we only modify first two axes
+ assert new_shape[2:] == old_shape[2:]
+ # assumes first axis corresponds to output dim
+ if not new_shape == old_shape:
+ new_param = param.clone()
+ old_param = sd[name]
+ if len(new_shape) == 1:
+ for i in range(new_param.shape[0]):
+ new_param[i] = old_param[i % old_shape[0]]
+ elif len(new_shape) >= 2:
+ for i in range(new_param.shape[0]):
+ for j in range(new_param.shape[1]):
+ new_param[i, j] = old_param[
+ i % old_shape[0], j % old_shape[1]
+ ]
+
+ n_used_old = torch.ones(old_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_old[j % old_shape[1]] += 1
+ n_used_new = torch.zeros(new_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_new[j] = n_used_old[j % old_shape[1]]
+
+ n_used_new = n_used_new[None, :]
+ while len(n_used_new.shape) < len(new_shape):
+ n_used_new = n_used_new.unsqueeze(-1)
+ new_param /= n_used_new
+
+ sd[name] = new_param
+
+ missing, unexpected = (
+ self.load_state_dict(sd, strict=False)
+ if not only_model
+ else self.model.load_state_dict(sd, strict=False)
+ )
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys:\n {missing}")
+ if len(unexpected) > 0:
+ print(f"\nUnexpected Keys:\n {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+ * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1.0, 1.0)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(
+ x=x, t=t, clip_denoised=clip_denoised
+ )
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(
+ reversed(range(0, self.num_timesteps)),
+ desc="Sampling t",
+ total=self.num_timesteps,
+ ):
+ img = self.p_sample(
+ img,
+ torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised,
+ )
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop(
+ (batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates,
+ )
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == "l1":
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == "l2":
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(
+ f"Parameterization {self.parameterization} not yet supported"
+ )
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = "train" if self.training else "val"
+
+ loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f"{log_prefix}/loss": loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
+ ).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, "b h w c -> b c h w")
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ for k in self.ucg_training:
+ p = self.ucg_training[k]["p"]
+ val = self.ucg_training[k]["val"]
+ if val is None:
+ val = ""
+ for i in range(len(batch[k])):
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[k][i] = val
+
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
+ )
+
+ self.log(
+ "global_step",
+ self.global_step,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]["lr"]
+ self.log(
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(
+ loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
+ )
+ self.log_dict(
+ loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
+ )
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(
+ batch_size=N, return_intermediates=True
+ )
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(
+ self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ force_null_conditioning=False,
+ *args,
+ **kwargs,
+ ):
+ self.force_null_conditioning = force_null_conditioning
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = "concat" if concat_mode else "crossattn"
+ if (
+ cond_stage_config == "__is_unconditional__"
+ and not self.force_null_conditioning
+ ):
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ reset_ema = kwargs.pop("reset_ema", False)
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+ if reset_ema:
+ assert self.use_ema
+ print(
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
+ )
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(
+ " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
+ )
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ def make_cond_schedule(
+ self,
+ ):
+ self.cond_ids = torch.full(
+ size=(self.num_timesteps,),
+ fill_value=self.num_timesteps - 1,
+ dtype=torch.long,
+ )
+ ids = torch.round(
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
+ ).long()
+ self.cond_ids[: self.num_timesteps_cond] = ids
+
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if (
+ self.scale_by_std
+ and self.current_epoch == 0
+ and self.global_step == 0
+ and batch_idx == 0
+ and not self.restarted_from_ckpt
+ ):
+ assert (
+ self.scale_factor == 1.0
+ ), "rather not use custom rescaling and std-rescaling simultaneously"
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer("scale_factor", 1.0 / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(
+ self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ super().register_schedule(
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
+ )
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != "__is_first_stage__"
+ assert config != "__is_unconditional__"
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(
+ self, samples, desc="", force_no_decoder_quantization=False
+ ):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(
+ self.decode_first_stage(
+ zd.to(self.device), force_not_quantize=force_no_decoder_quantization
+ )
+ )
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+ )
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, "encode") and callable(
+ self.cond_stage_model.encode
+ ):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+ wtith min distance = 0 at border and max dist = 0.5 at image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(
+ torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
+ )[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(
+ weighting,
+ self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"],
+ )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(
+ L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"],
+ )
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(
+ self, x, kernel_size, stride, uf=1, df=1
+ ): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
+ )
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(
+ kernel_size[0], kernel_size[1], Ly, Lx, x.device
+ ).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
+ )
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(
+ kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1,
+ padding=0,
+ stride=(stride[0] * uf, stride[1] * uf),
+ )
+ fold = torch.nn.Fold(
+ output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
+ )
+
+ weighting = self.get_weighting(
+ kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
+ ).to(x.dtype)
+ normalization = fold(weighting).view(
+ 1, 1, h * uf, w * uf
+ ) # normalizes the overlap
+ weighting = weighting.view(
+ (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
+ )
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
+ )
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(
+ kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1,
+ padding=0,
+ stride=(stride[0] // df, stride[1] // df),
+ )
+ fold = torch.nn.Fold(
+ output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2
+ )
+
+ weighting = self.get_weighting(
+ kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device
+ ).to(x.dtype)
+ normalization = fold(weighting).view(
+ 1, 1, h // df, w // df
+ ) # normalizes the overlap
+ weighting = weighting.view(
+ (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)
+ )
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(
+ self,
+ batch,
+ k,
+ return_first_stage_outputs=False,
+ force_c_encode=False,
+ cond_key=None,
+ return_original_cond=False,
+ bs=None,
+ return_x=False,
+ mask_k=None,
+ ):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if mask_k is not None:
+ mx = super().get_input(batch, mask_k)
+ if bs is not None:
+ mx = mx[:bs]
+ mx = mx.to(self.device)
+ encoder_posterior = self.encode_first_stage(mx)
+ mx = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ["caption", "coordinates_bbox", "txt"]:
+ xc = batch[cond_key]
+ elif cond_key in ["class_label", "cls"]:
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {"pos_x": pos_x, "pos_y": pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_x:
+ out.extend([x])
+ if return_original_cond:
+ out.append(xc)
+ if mask_k:
+ out.append(mx)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
+
+ z = 1.0 / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
+
+ z = 1.0 / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
+ ).long()
+ # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+ if isinstance(cond, dict):
+ # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = (
+ "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
+ )
+ cond = {key: cond}
+
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_mean_variance(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised: bool,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ score_corrector=None,
+ corrector_kwargs=None,
+ ):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(
+ self, model_out, x, t, c, **corrector_kwargs
+ )
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1.0, 1.0)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised=False,
+ repeat_noise=False,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ ):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(
+ x=x,
+ c=c,
+ t=t,
+ clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ )
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (
+ 0.5 * model_log_variance
+ ).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return (
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
+ x0,
+ )
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(
+ self,
+ cond,
+ shape,
+ verbose=True,
+ callback=None,
+ quantize_denoised=False,
+ img_callback=None,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ batch_size=None,
+ x_T=None,
+ start_T=None,
+ log_every_t=None,
+ ):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key: cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ for key in cond
+ }
+ else:
+ cond = (
+ [c[:batch_size] for c in cond]
+ if isinstance(cond, list)
+ else cond[:batch_size]
+ )
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = (
+ tqdm(
+ reversed(range(0, timesteps)),
+ desc="Progressive Generation",
+ total=timesteps,
+ )
+ if verbose
+ else reversed(range(0, timesteps))
+ )
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != "hybrid"
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(
+ img,
+ cond,
+ ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised,
+ return_x0=True,
+ temperature=temperature[i],
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ )
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(
+ self,
+ cond,
+ shape,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ start_T=None,
+ log_every_t=None,
+ ):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = (
+ tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
+ if verbose
+ else reversed(range(0, timesteps))
+ )
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != "hybrid"
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(
+ img,
+ cond,
+ ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised,
+ )
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond,
+ batch_size=16,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ shape=None,
+ **kwargs,
+ ):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key: cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ for key in cond
+ }
+ else:
+ cond = (
+ [c[:batch_size] for c in cond]
+ if isinstance(cond, list)
+ else cond[:batch_size]
+ )
+ return self.p_sample_loop(
+ cond,
+ shape,
+ return_intermediates=return_intermediates,
+ x_T=x_T,
+ verbose=verbose,
+ timesteps=timesteps,
+ quantize_denoised=quantize_denoised,
+ mask=mask,
+ x0=x0,
+ )
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates = ddim_sampler.sample(
+ ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
+ )
+
+ else:
+ samples, intermediates = self.sample(
+ cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
+ )
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ if self.cond_stage_key in ["class_label", "cls"]:
+ xc = self.cond_stage_model.get_unconditional_conditioning(
+ batch_size, device=self.device
+ )
+ return self.get_learned_conditioning(xc)
+ else:
+ raise NotImplementedError("todo")
+ if isinstance(c, list): # in case the encoder gives us a list
+ for i in range(len(c)):
+ c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device)
+ else:
+ c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+ return c
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch,
+ N=8,
+ n_row=4,
+ sample=True,
+ ddim_steps=50,
+ ddim_eta=0.0,
+ return_keys=None,
+ quantize_denoised=True,
+ inpaint=True,
+ plot_denoise_rows=False,
+ plot_progressive_rows=True,
+ plot_diffusion_rows=True,
+ unconditional_guidance_scale=1.0,
+ unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs,
+ ):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(
+ batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N,
+ )
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img(
+ (x.shape[2], x.shape[3]),
+ batch[self.cond_stage_key],
+ size=x.shape[2] // 25,
+ )
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["class_label", "cls"]:
+ try:
+ xc = log_txt_as_img(
+ (x.shape[2], x.shape[3]),
+ batch["human_label"],
+ size=x.shape[2] // 25,
+ )
+ log["conditioning"] = xc
+ except KeyError:
+ # probably no "human_label" in batch
+ pass
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ )
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if (
+ quantize_denoised
+ and not isinstance(self.first_stage_model, AutoencoderKL)
+ and not isinstance(self.first_stage_model, IdentityFirstStage)
+ ):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ quantize_denoised=True,
+ )
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ if self.model.conditioning_key == "crossattn-adm":
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
+ ] = x_samples_cfg
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+ samples, _ = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ eta=ddim_eta,
+ ddim_steps=ddim_steps,
+ x0=z[:N],
+ mask=mask,
+ )
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ mask = 1.0 - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ eta=ddim_eta,
+ ddim_steps=ddim_steps,
+ x0=z[:N],
+ mask=mask,
+ )
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(
+ c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N,
+ )
+ prog_row = self._get_denoise_row_from_list(
+ progressives, desc="Progressive Generation"
+ )
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print("Diffusion model optimizing logvar")
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert "target" in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+ "interval": "step",
+ "frequency": 1,
+ }
+ ]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+ return x
+
+
+class DiffusionWrapper(torch.nn.Module):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.sequential_cross_attn = diff_model_config.pop(
+ "sequential_crossattn", False
+ )
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [
+ None,
+ "concat",
+ "crossattn",
+ "hybrid",
+ "adm",
+ "hybrid-adm",
+ "crossattn-adm",
+ ]
+
+ def forward(
+ self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None
+ ):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == "concat":
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == "crossattn":
+ if not self.sequential_cross_attn:
+ cc = torch.cat(c_crossattn, 1)
+ else:
+ cc = c_crossattn
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == "hybrid":
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == "hybrid-adm":
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+ elif self.conditioning_key == "crossattn-adm":
+ assert c_adm is not None
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
+ elif self.conditioning_key == "adm":
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+ def __init__(
+ self,
+ *args,
+ low_scale_config,
+ low_scale_key="LR",
+ noise_level_key=None,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+ self.noise_level_key = noise_level_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(
+ batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs,
+ )
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, "b h w c -> b c h w")
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ zx, noise_level = self.low_scale_model(x_low)
+ if self.noise_level_key is not None:
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+ raise NotImplementedError("TODO")
+
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ x_low_rec = self.low_scale_model.decode(zx)
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch,
+ N=8,
+ n_row=4,
+ sample=True,
+ ddim_steps=200,
+ ddim_eta=1.0,
+ return_keys=None,
+ plot_denoise_rows=False,
+ plot_progressive_rows=True,
+ plot_diffusion_rows=True,
+ unconditional_guidance_scale=1.0,
+ unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs,
+ ):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(
+ batch, self.first_stage_key, bs=N, log_mode=True
+ )
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+ log[
+ f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"
+ ] = x_low_rec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img(
+ (x.shape[2], x.shape[3]),
+ batch[self.cond_stage_key],
+ size=x.shape[2] // 25,
+ )
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["class_label", "cls"]:
+ xc = log_txt_as_img(
+ (x.shape[2], x.shape[3]),
+ batch["human_label"],
+ size=x.shape[2] // 25,
+ )
+ log["conditioning"] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ )
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(
+ N, unconditional_guidance_label
+ )
+ # TODO explore better "unconditional" choices for the other keys
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif k == "c_adm": # todo: only run with text-based guidance?
+ assert isinstance(c[k], torch.Tensor)
+ # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+ uc[k] = c[k]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
+ ] = x_samples_cfg
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(
+ c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N,
+ )
+ prog_row = self._get_denoise_row_from_list(
+ progressives, desc="Progressive Generation"
+ )
+ log["progressive_row"] = prog_row
+
+ return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+ """
+ Basis for different finetunas, such as inpainting or depth2image
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(
+ self,
+ concat_keys: tuple,
+ finetune_keys=(
+ "model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight",
+ ),
+ keep_finetune_dims=4,
+ # if model was trained without concat mode before and we would like to keep these channels
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
+ c_concat_log_end=None,
+ *args,
+ **kwargs,
+ ):
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", list())
+ super().__init__(*args, **kwargs)
+ self.finetune_keys = finetune_keys
+ self.concat_keys = concat_keys
+ self.keep_dims = keep_finetune_dims
+ self.c_concat_log_start = c_concat_log_start
+ self.c_concat_log_end = c_concat_log_end
+ if exists(self.finetune_keys):
+ assert exists(ckpt_path), "can only finetune from a given checkpoint"
+ if exists(ckpt_path):
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+
+ # make it explicit, finetune by including extra input channels
+ if exists(self.finetune_keys) and k in self.finetune_keys:
+ new_entry = None
+ for name, param in self.named_parameters():
+ if name in self.finetune_keys:
+ print(
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
+ )
+ new_entry = torch.zeros_like(param) # zero init
+ assert exists(new_entry), "did not find matching parameter to modify"
+ new_entry[:, : self.keep_dims, ...] = sd[k]
+ sd[k] = new_entry
+
+ missing, unexpected = (
+ self.load_state_dict(sd, strict=False)
+ if not only_model
+ else self.model.load_state_dict(sd, strict=False)
+ )
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch,
+ N=8,
+ n_row=4,
+ sample=True,
+ ddim_steps=200,
+ ddim_eta=1.0,
+ return_keys=None,
+ quantize_denoised=True,
+ inpaint=True,
+ plot_denoise_rows=False,
+ plot_progressive_rows=True,
+ plot_diffusion_rows=True,
+ unconditional_guidance_scale=1.0,
+ unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs,
+ ):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(
+ batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
+ )
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img(
+ (x.shape[2], x.shape[3]),
+ batch[self.cond_stage_key],
+ size=x.shape[2] // 25,
+ )
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["class_label", "cls"]:
+ xc = log_txt_as_img(
+ (x.shape[2], x.shape[3]),
+ batch["human_label"],
+ size=x.shape[2] // 25,
+ )
+ log["conditioning"] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+ log["c_concat_decoded"] = self.decode_first_stage(
+ c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
+ )
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(
+ cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ )
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(
+ N, unconditional_guidance_label
+ )
+ uc_cat = c_cat
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(
+ cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
+ ] = x_samples_cfg
+
+ return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+ """
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+ e.g. mask as concat and text via cross-attn.
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(
+ self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args,
+ **kwargs,
+ ):
+ super().__init__(concat_keys, *args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+
+ @torch.no_grad()
+ def get_input(
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+ ):
+ # note: restricted to non-trainable encoders currently
+ assert (
+ not self.cond_stage_trainable
+ ), "trainable cond stages not yet supported for inpainting"
+ z, c, x, xrec, xc = super().get_input(
+ batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs,
+ )
+
+ assert exists(self.concat_keys)
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = (
+ rearrange(batch[ck], "b h w c -> b c h w")
+ .to(memory_format=torch.contiguous_format)
+ .float()
+ )
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ bchw = z.shape
+ if ck != self.masked_image_key:
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+ log["masked_image"] = (
+ rearrange(args[0]["masked_image"], "b h w c -> b c h w")
+ .to(memory_format=torch.contiguous_format)
+ .float()
+ )
+ return log
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on monocular depth estimation
+ """
+
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.depth_model = instantiate_from_config(depth_stage_config)
+ self.depth_stage_key = concat_keys[0]
+
+ @torch.no_grad()
+ def get_input(
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+ ):
+ # note: restricted to non-trainable encoders currently
+ assert (
+ not self.cond_stage_trainable
+ ), "trainable cond stages not yet supported for depth2img"
+ z, c, x, xrec, xc = super().get_input(
+ batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs,
+ )
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ cc = self.depth_model(cc)
+ cc = torch.nn.functional.interpolate(
+ cc,
+ size=z.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ depth_min, depth_max = torch.amin(
+ cc, dim=[1, 2, 3], keepdim=True
+ ), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
+ cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ depth = self.depth_model(args[0][self.depth_stage_key])
+ depth_min, depth_max = torch.amin(
+ depth, dim=[1, 2, 3], keepdim=True
+ ), torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+ log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
+ return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on low-res image (and optionally on some spatial noise augmentation)
+ """
+
+ def __init__(
+ self,
+ concat_keys=("lr",),
+ reshuffle_patch_size=None,
+ low_scale_config=None,
+ low_scale_key=None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.reshuffle_patch_size = reshuffle_patch_size
+ self.low_scale_model = None
+ if low_scale_config is not None:
+ print("Initializing a low-scale model")
+ assert exists(low_scale_key)
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+ ):
+ # note: restricted to non-trainable encoders currently
+ assert (
+ not self.cond_stage_trainable
+ ), "trainable cond stages not yet supported for upscaling-ft"
+ z, c, x, xrec, xc = super().get_input(
+ batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs,
+ )
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ # optionally make spatial noise_level here
+ c_cat = list()
+ noise_level = None
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ cc = rearrange(cc, "b h w c -> b c h w")
+ if exists(self.reshuffle_patch_size):
+ assert isinstance(self.reshuffle_patch_size, int)
+ cc = rearrange(
+ cc,
+ "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
+ p1=self.reshuffle_patch_size,
+ p2=self.reshuffle_patch_size,
+ )
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
+ cc, noise_level = self.low_scale_model(cc)
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ if exists(noise_level):
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+ else:
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
+ return log
diff --git a/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py b/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000..7427f38
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py b/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000..095e5ba
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+ t = self.inverse_lambda(lambda_t)
+ ===============================================================
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+ 1. For discrete-time DPMs:
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+ 2. For continuous-time DPMs:
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+ ===============================================================
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+ Example:
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+ We support four types of the diffusion model by setting `model_type`:
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+ ===============================================================
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3, ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3, ] * (K - 1) + [1]
+ else:
+ orders = [3, ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2, ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2, ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1, ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+ torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+ solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ solver_type=solver_type,
+ **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ return_intermediate=True,
+ solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+ solver_type=solver_type,
+ **kwargs)
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+ =====================================================
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+ =====================================================
+ Some advices for choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+ - 'time_quadratic': quadratic time for the time steps.
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: A `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: A `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
+ (such as CIFAR-10). However, we observed that such trick does not matter for
+ high-resolutional images. As it needs an additional NFE, we do not recommend
+ it for high-resolutional images.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+ (especially for steps <= 10). So we recommend to set it to be `True`.
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+ solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+ solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+ solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+ skip_type=skip_type,
+ t_T=t_T, t_0=t_0,
+ device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order, ] * K
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+ N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py b/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000..7d137b8
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+ "eps": "noise",
+ "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=MODEL_TYPES[self.model.parameterization],
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+ return x.to(device), None
\ No newline at end of file
diff --git a/iopaint/model/anytext/ldm/models/diffusion/plms.py b/iopaint/model/anytext/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000..5f35d55
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from iopaint.model.anytext.ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/iopaint/model/anytext/ldm/models/diffusion/sampling_util.py b/iopaint/model/anytext/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000..7eff02b
--- /dev/null
+++ b/iopaint/model/anytext/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+ return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+ # b c h w
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+ return x0 * (value / s)
\ No newline at end of file
diff --git a/iopaint/model/anytext/ldm/modules/attention.py b/iopaint/model/anytext/ldm/modules/attention.py
new file mode 100644
index 0000000..df92aa7
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/attention.py
@@ -0,0 +1,360 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint
+
+
+# CrossAttn precision handling
+import os
+
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+ # force cast to fp32 to avoid overflowing
+ if _ATTN_PRECISION == "fp32":
+ with torch.autocast(enabled=False, device_type="cuda"):
+ q, k = q.float(), k.float()
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+ else:
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum("b i j, b j d -> b i d", sim, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+ return self.to_out(out)
+
+
+class SDPACrossAttention(CrossAttention):
+ def forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, inner_dim = x.shape
+
+ if mask is not None:
+ mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
+ mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
+
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
+
+ head_dim = inner_dim // h
+ q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+
+ del q_in, k_in, v_in
+
+ dtype = q.dtype
+ if _ATTN_PRECISION == "fp32":
+ q, k, v = q.float(), k.float(), v.float()
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, h * head_dim
+ )
+ hidden_states = hidden_states.to(dtype)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ ):
+ super().__init__()
+
+ if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+ attn_cls = SDPACrossAttention
+ else:
+ attn_cls = CrossAttention
+
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(
+ self._forward, (x, context), self.parameters(), self.checkpoint
+ )
+
+ def _forward(self, x, context=None):
+ x = (
+ self.attn1(
+ self.norm1(x), context=context if self.disable_self_attn else None
+ )
+ + x
+ )
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ use_checkpoint=True,
+ ):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ checkpoint=use_checkpoint,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
diff --git a/iopaint/model/anytext/ldm/modules/diffusionmodules/__init__.py b/iopaint/model/anytext/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/iopaint/model/anytext/ldm/modules/diffusionmodules/model.py b/iopaint/model/anytext/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000..3472824
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,973 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class AttnBlock2_0(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ # output: [1, 512, 64, 64]
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+
+ # q = q.reshape(b, c, h * w).transpose()
+ # q = q.permute(0, 2, 1) # b,hw,c
+ # k = k.reshape(b, c, h * w) # b,c,hw
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ # (batch, num_heads, seq_len, head_dim)
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = hidden_states.to(q.dtype)
+
+ h_ = self.proj_out(hidden_states)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ "memory-efficient-cross-attn",
+ "linear",
+ "none",
+ ], f"attn_type {attn_type} unknown"
+ assert attn_kwargs is None
+ if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+ # print(f"Using torch.nn.functional.scaled_dot_product_attention")
+ return AttnBlock2_0(in_channels)
+ return AttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb
+ )
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print(
+ "Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)
+ )
+ )
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, z):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList(
+ [
+ nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(
+ in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ nn.Conv2d(2 * in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True),
+ ]
+ )
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1, 2, 3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ ch,
+ num_res_blocks,
+ resolution,
+ ch_mult=(2, 2),
+ dropout=0.0,
+ ):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.res_block1 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.conv_out = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(
+ x,
+ size=(
+ int(round(x.shape[2] * self.factor)),
+ int(round(x.shape[3] * self.factor)),
+ ),
+ )
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ ch,
+ resolution,
+ out_ch,
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ ch_mult=(1, 2, 4, 8),
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ num_res_blocks=num_res_blocks,
+ ch=ch,
+ ch_mult=ch_mult,
+ z_channels=intermediate_chn,
+ double_z=False,
+ resolution=resolution,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ out_ch=None,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=intermediate_chn,
+ mid_channels=intermediate_chn,
+ out_channels=out_ch,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(
+ self,
+ z_channels,
+ out_ch,
+ resolution,
+ num_res_blocks,
+ attn_resolutions,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ tmp_chn = z_channels * ch_mult[-1]
+ self.decoder = Decoder(
+ out_ch=out_ch,
+ z_channels=tmp_chn,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ in_channels=None,
+ num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult,
+ resolution=resolution,
+ ch=ch,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=z_channels,
+ mid_channels=tmp_chn,
+ out_channels=tmp_chn,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size // in_size)) + 1
+ factor_up = 1.0 + (out_size % in_size)
+ print(
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+ )
+ self.rescaler = LatentRescaler(
+ factor=factor_up,
+ in_channels=in_channels,
+ mid_channels=2 * in_channels,
+ out_channels=in_channels,
+ )
+ self.decoder = Decoder(
+ out_ch=out_channels,
+ resolution=out_size,
+ z_channels=in_channels,
+ num_res_blocks=2,
+ attn_resolutions=[],
+ in_channels=None,
+ ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)],
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
+ )
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
+ )
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor == 1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+ )
+ return x
diff --git a/iopaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py b/iopaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000..fd3d6be
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
+from iopaint.model.anytext.ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+ self.use_fp16 = use_fp16
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
diff --git a/iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py b/iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000..5f92630
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from iopaint.model.anytext.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from iopaint.model.anytext.ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+ # for concatenating a downsampled image to the latent representation
+ def __init__(self, noise_schedule_config=None):
+ super(AbstractLowScaleModel, self).__init__()
+ if noise_schedule_config is not None:
+ self.register_schedule(**noise_schedule_config)
+
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def forward(self, x):
+ return x, None
+
+ def decode(self, x):
+ return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+ # no noise level conditioning
+ def __init__(self):
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+ self.max_noise_level = 0
+
+ def forward(self, x):
+ # fix to constant noise level
+ return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+ super().__init__(noise_schedule_config=noise_schedule_config)
+ self.max_noise_level = max_noise_level
+
+ def forward(self, x, noise_level=None):
+ if noise_level is None:
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ else:
+ assert isinstance(noise_level, torch.Tensor)
+ z = self.q_sample(x, noise_level)
+ return z, noise_level
+
+
+
diff --git a/iopaint/model/anytext/ldm/modules/diffusionmodules/util.py b/iopaint/model/anytext/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000..da29c72
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,271 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from iopaint.model.anytext.ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas.to(torch.float32), alphas.to(torch.float32), alphas_prev.astype(np.float32)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ # return super().forward(x.float()).type(x.dtype)
+ return super().forward(x).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/iopaint/model/anytext/ldm/modules/distributions/__init__.py b/iopaint/model/anytext/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/iopaint/model/anytext/ldm/modules/distributions/distributions.py b/iopaint/model/anytext/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000..f2b8ef9
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/iopaint/model/anytext/ldm/modules/ema.py b/iopaint/model/anytext/ldm/modules/ema.py
new file mode 100644
index 0000000..bded250
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/iopaint/model/anytext/ldm/modules/encoders/__init__.py b/iopaint/model/anytext/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/iopaint/model/anytext/ldm/modules/encoders/modules.py b/iopaint/model/anytext/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000..e7e2d0a
--- /dev/null
+++ b/iopaint/model/anytext/ldm/modules/encoders/modules.py
@@ -0,0 +1,384 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPVisionModelWithProjection
+
+from iopaint.model.anytext.ldm.util import count_params
+
+
+def _expand_mask(mask, dtype, tgt_len=None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+def _build_causal_attention_mask(bsz, seq_len, dtype):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+ mask.fill_(torch.tensor(torch.finfo(dtype).min))
+ mask.triu_(1) # zero out the lower diagonal
+ mask = mask.unsqueeze(1) # expand mask
+ return mask
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.ucg_rate = ucg_rate
+
+ def forward(self, batch, key=None, disable_dropout=False):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ if self.ucg_rate > 0. and not disable_dropout:
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+ c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+ c = c.long()
+ c = self.embedding(c)
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc}
+ return uc
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ LAYERS = [
+ "last",
+ "pooled",
+ "hidden"
+ ]
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate"
+ ]
+
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+ clip_max_length=77, t5_max_length=77):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+class FrozenCLIPEmbedderT3(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ if use_vision:
+ self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
+ self.processor = AutoProcessor.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def embedding_forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ embedding_manager=None,
+ ):
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+ if embedding_manager is not None:
+ inputs_embeds = embedding_manager(input_ids, inputs_embeds)
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+ return embeddings
+
+ self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
+
+ def encoder_forward(
+ self,
+ inputs_embeds,
+ attention_mask=None,
+ causal_attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ return hidden_states
+
+ self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
+
+ def text_encoder_forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ embedding_manager=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if input_ids is None:
+ raise ValueError("You have to specify either input_ids")
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
+ bsz, seq_len = input_shape
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
+ hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+ last_hidden_state = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+ return last_hidden_state
+
+ self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
+
+ def transformer_forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ embedding_manager=None,
+ ):
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ embedding_manager=embedding_manager
+ )
+
+ self.transformer.forward = transformer_forward.__get__(self.transformer)
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text, **kwargs):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ z = self.transformer(input_ids=tokens, **kwargs)
+ return z
+
+ def encode(self, text, **kwargs):
+ return self(text, **kwargs)
diff --git a/iopaint/model/anytext/ldm/util.py b/iopaint/model/anytext/ldm/util.py
new file mode 100644
index 0000000..d456a86
--- /dev/null
+++ b/iopaint/model/anytext/ldm/util.py
@@ -0,0 +1,197 @@
+import importlib
+
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
+ nc = int(32 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config, **kwargs):
+ if "target" not in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
+ ema_power=1., param_names=()):
+ """AdamW that saves EMA versions of the parameters."""
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= ema_decay <= 1.0:
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+ ema_power=ema_power, param_names=param_names)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ ema_params_with_grad = []
+ state_sums = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ ema_decay = group['ema_decay']
+ ema_power = group['ema_power']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of parameter values
+ state['param_exp_avg'] = p.detach().float().clone()
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ ema_params_with_grad.append(state['param_exp_avg'])
+
+ if amsgrad:
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+ # update the steps for each param group update
+ state['step'] += 1
+ # record the step after step update
+ state_steps.append(state['step'])
+
+ optim._functional.adamw(params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=False)
+
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+ return loss
\ No newline at end of file
diff --git a/iopaint/model/anytext/main.py b/iopaint/model/anytext/main.py
new file mode 100644
index 0000000..dbafe50
--- /dev/null
+++ b/iopaint/model/anytext/main.py
@@ -0,0 +1,52 @@
+from anytext_pipeline import AnyTextPipeline
+from utils import save_images
+
+seed = 66273235
+# seed_everything(seed)
+
+pipe = AnyTextPipeline(
+ cfg_path="/Users/cwq/code/github/AnyText/anytext/models_yaMl/anytext_sd15.yaml",
+ model_dir="/Users/cwq/.cache/modelscope/hub/damo/cv_anytext_text_generation_editing",
+ # font_path="/Users/cwq/code/github/AnyText/anytext/font/Arial_Unicode.ttf",
+ # font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-VF.ttf",
+ font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf",
+ use_fp16=False,
+ device="mps",
+)
+
+img_save_folder = "SaveImages"
+params = {
+ "show_debug": True,
+ "image_count": 2,
+ "ddim_steps": 20,
+}
+
+# # 1. text generation
+# mode = "text-generation"
+# input_data = {
+# "prompt": 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream',
+# "seed": seed,
+# "draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/gen9.png",
+# }
+# results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
+# if rtn_code >= 0:
+# save_images(results, img_save_folder)
+# print(f"Done, result images are saved in: {img_save_folder}")
+# if rtn_warning:
+# print(rtn_warning)
+#
+# exit()
+# 2. text editing
+mode = "text-editing"
+input_data = {
+ "prompt": 'A cake with colorful characters that reads "EVERYDAY"',
+ "seed": seed,
+ "draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png",
+ "ori_image": "/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg",
+}
+results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
+if rtn_code >= 0:
+ save_images(results, img_save_folder)
+ print(f"Done, result images are saved in: {img_save_folder}")
+if rtn_warning:
+ print(rtn_warning)
diff --git a/iopaint/model/anytext/ocr_recog/RNN.py b/iopaint/model/anytext/ocr_recog/RNN.py
new file mode 100755
index 0000000..cf16855
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/RNN.py
@@ -0,0 +1,210 @@
+from torch import nn
+import torch
+from .RecSVTR import Block
+
+class Swish(nn.Module):
+ def __int__(self):
+ super(Swish, self).__int__()
+
+ def forward(self,x):
+ return x*torch.sigmoid(x)
+
+class Im2Im(nn.Module):
+ def __init__(self, in_channels, **kwargs):
+ super().__init__()
+ self.out_channels = in_channels
+
+ def forward(self, x):
+ return x
+
+class Im2Seq(nn.Module):
+ def __init__(self, in_channels, **kwargs):
+ super().__init__()
+ self.out_channels = in_channels
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # assert H == 1
+ x = x.reshape(B, C, H * W)
+ x = x.permute((0, 2, 1))
+ return x
+
+class EncoderWithRNN(nn.Module):
+ def __init__(self, in_channels,**kwargs):
+ super(EncoderWithRNN, self).__init__()
+ hidden_size = kwargs.get('hidden_size', 256)
+ self.out_channels = hidden_size * 2
+ self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True)
+
+ def forward(self, x):
+ self.lstm.flatten_parameters()
+ x, _ = self.lstm(x)
+ return x
+
+class SequenceEncoder(nn.Module):
+ def __init__(self, in_channels, encoder_type='rnn', **kwargs):
+ super(SequenceEncoder, self).__init__()
+ self.encoder_reshape = Im2Seq(in_channels)
+ self.out_channels = self.encoder_reshape.out_channels
+ self.encoder_type = encoder_type
+ if encoder_type == 'reshape':
+ self.only_reshape = True
+ else:
+ support_encoder_dict = {
+ 'reshape': Im2Seq,
+ 'rnn': EncoderWithRNN,
+ 'svtr': EncoderWithSVTR
+ }
+ assert encoder_type in support_encoder_dict, '{} must in {}'.format(
+ encoder_type, support_encoder_dict.keys())
+
+ self.encoder = support_encoder_dict[encoder_type](
+ self.encoder_reshape.out_channels,**kwargs)
+ self.out_channels = self.encoder.out_channels
+ self.only_reshape = False
+
+ def forward(self, x):
+ if self.encoder_type != 'svtr':
+ x = self.encoder_reshape(x)
+ if not self.only_reshape:
+ x = self.encoder(x)
+ return x
+ else:
+ x = self.encoder(x)
+ x = self.encoder_reshape(x)
+ return x
+
+class ConvBNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias_attr=False,
+ groups=1,
+ act=nn.GELU):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+ bias=bias_attr)
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = Swish()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class EncoderWithSVTR(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ dims=64, # XS
+ depth=2,
+ hidden_dims=120,
+ use_guide=False,
+ num_heads=8,
+ qkv_bias=True,
+ mlp_ratio=2.0,
+ drop_rate=0.1,
+ attn_drop_rate=0.1,
+ drop_path=0.,
+ qk_scale=None):
+ super(EncoderWithSVTR, self).__init__()
+ self.depth = depth
+ self.use_guide = use_guide
+ self.conv1 = ConvBNLayer(
+ in_channels, in_channels // 8, padding=1, act='swish')
+ self.conv2 = ConvBNLayer(
+ in_channels // 8, hidden_dims, kernel_size=1, act='swish')
+
+ self.svtr_block = nn.ModuleList([
+ Block(
+ dim=hidden_dims,
+ num_heads=num_heads,
+ mixer='Global',
+ HW=None,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer='swish',
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-05,
+ prenorm=False) for i in range(depth)
+ ])
+ self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
+ self.conv3 = ConvBNLayer(
+ hidden_dims, in_channels, kernel_size=1, act='swish')
+ # last conv-nxn, the input is concat of input tensor and conv3 output tensor
+ self.conv4 = ConvBNLayer(
+ 2 * in_channels, in_channels // 8, padding=1, act='swish')
+
+ self.conv1x1 = ConvBNLayer(
+ in_channels // 8, dims, kernel_size=1, act='swish')
+ self.out_channels = dims
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ # weight initialization
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.ConvTranspose2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ # for use guide
+ if self.use_guide:
+ z = x.clone()
+ z.stop_gradient = True
+ else:
+ z = x
+ # for short cut
+ h = z
+ # reduce dim
+ z = self.conv1(z)
+ z = self.conv2(z)
+ # SVTR global block
+ B, C, H, W = z.shape
+ z = z.flatten(2).permute(0, 2, 1)
+
+ for blk in self.svtr_block:
+ z = blk(z)
+
+ z = self.norm(z)
+ # last stage
+ z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
+ z = self.conv3(z)
+ z = torch.cat((h, z), dim=1)
+ z = self.conv1x1(self.conv4(z))
+
+ return z
+
+if __name__=="__main__":
+ svtrRNN = EncoderWithSVTR(56)
+ print(svtrRNN)
\ No newline at end of file
diff --git a/iopaint/model/anytext/ocr_recog/RecCTCHead.py b/iopaint/model/anytext/ocr_recog/RecCTCHead.py
new file mode 100755
index 0000000..867ede9
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/RecCTCHead.py
@@ -0,0 +1,48 @@
+from torch import nn
+
+
+class CTCHead(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels=6625,
+ fc_decay=0.0004,
+ mid_channels=None,
+ return_feats=False,
+ **kwargs):
+ super(CTCHead, self).__init__()
+ if mid_channels is None:
+ self.fc = nn.Linear(
+ in_channels,
+ out_channels,
+ bias=True,)
+ else:
+ self.fc1 = nn.Linear(
+ in_channels,
+ mid_channels,
+ bias=True,
+ )
+ self.fc2 = nn.Linear(
+ mid_channels,
+ out_channels,
+ bias=True,
+ )
+
+ self.out_channels = out_channels
+ self.mid_channels = mid_channels
+ self.return_feats = return_feats
+
+ def forward(self, x, labels=None):
+ if self.mid_channels is None:
+ predicts = self.fc(x)
+ else:
+ x = self.fc1(x)
+ predicts = self.fc2(x)
+
+ if self.return_feats:
+ result = dict()
+ result['ctc'] = predicts
+ result['ctc_neck'] = x
+ else:
+ result = predicts
+
+ return result
diff --git a/iopaint/model/anytext/ocr_recog/RecModel.py b/iopaint/model/anytext/ocr_recog/RecModel.py
new file mode 100755
index 0000000..c2313bf
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/RecModel.py
@@ -0,0 +1,45 @@
+from torch import nn
+from .RNN import SequenceEncoder, Im2Seq, Im2Im
+from .RecMv1_enhance import MobileNetV1Enhance
+
+from .RecCTCHead import CTCHead
+
+backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance}
+neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
+head_dict = {'CTCHead':CTCHead}
+
+
+class RecModel(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ assert 'in_channels' in config, 'in_channels must in model config'
+ backbone_type = config.backbone.pop('type')
+ assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
+ self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
+
+ neck_type = config.neck.pop('type')
+ assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
+ self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
+
+ head_type = config.head.pop('type')
+ assert head_type in head_dict, f'head.type must in {head_dict}'
+ self.head = head_dict[head_type](self.neck.out_channels, **config.head)
+
+ self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'
+
+ def load_3rd_state_dict(self, _3rd_name, _state):
+ self.backbone.load_3rd_state_dict(_3rd_name, _state)
+ self.neck.load_3rd_state_dict(_3rd_name, _state)
+ self.head.load_3rd_state_dict(_3rd_name, _state)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ x = self.neck(x)
+ x = self.head(x)
+ return x
+
+ def encode(self, x):
+ x = self.backbone(x)
+ x = self.neck(x)
+ x = self.head.ctc_encoder(x)
+ return x
diff --git a/iopaint/model/anytext/ocr_recog/RecMv1_enhance.py b/iopaint/model/anytext/ocr_recog/RecMv1_enhance.py
new file mode 100644
index 0000000..7529b4a
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/RecMv1_enhance.py
@@ -0,0 +1,232 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .common import Activation
+
+
+class ConvBNLayer(nn.Module):
+ def __init__(self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ channels=None,
+ num_groups=1,
+ act='hard_swish'):
+ super(ConvBNLayer, self).__init__()
+ self.act = act
+ self._conv = nn.Conv2d(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ bias=False)
+
+ self._batch_norm = nn.BatchNorm2d(
+ num_filters,
+ )
+ if self.act is not None:
+ self._act = Activation(act_type=act, inplace=True)
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ if self.act is not None:
+ y = self._act(y)
+ return y
+
+
+class DepthwiseSeparable(nn.Module):
+ def __init__(self,
+ num_channels,
+ num_filters1,
+ num_filters2,
+ num_groups,
+ stride,
+ scale,
+ dw_size=3,
+ padding=1,
+ use_se=False):
+ super(DepthwiseSeparable, self).__init__()
+ self.use_se = use_se
+ self._depthwise_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=int(num_filters1 * scale),
+ filter_size=dw_size,
+ stride=stride,
+ padding=padding,
+ num_groups=int(num_groups * scale))
+ if use_se:
+ self._se = SEModule(int(num_filters1 * scale))
+ self._pointwise_conv = ConvBNLayer(
+ num_channels=int(num_filters1 * scale),
+ filter_size=1,
+ num_filters=int(num_filters2 * scale),
+ stride=1,
+ padding=0)
+
+ def forward(self, inputs):
+ y = self._depthwise_conv(inputs)
+ if self.use_se:
+ y = self._se(y)
+ y = self._pointwise_conv(y)
+ return y
+
+
+class MobileNetV1Enhance(nn.Module):
+ def __init__(self,
+ in_channels=3,
+ scale=0.5,
+ last_conv_stride=1,
+ last_pool_type='max',
+ **kwargs):
+ super().__init__()
+ self.scale = scale
+ self.block_list = []
+
+ self.conv1 = ConvBNLayer(
+ num_channels=in_channels,
+ filter_size=3,
+ channels=3,
+ num_filters=int(32 * scale),
+ stride=2,
+ padding=1)
+
+ conv2_1 = DepthwiseSeparable(
+ num_channels=int(32 * scale),
+ num_filters1=32,
+ num_filters2=64,
+ num_groups=32,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv2_1)
+
+ conv2_2 = DepthwiseSeparable(
+ num_channels=int(64 * scale),
+ num_filters1=64,
+ num_filters2=128,
+ num_groups=64,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv2_2)
+
+ conv3_1 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=128,
+ num_groups=128,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv3_1)
+
+ conv3_2 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=256,
+ num_groups=128,
+ stride=(2, 1),
+ scale=scale)
+ self.block_list.append(conv3_2)
+
+ conv4_1 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=256,
+ num_groups=256,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv4_1)
+
+ conv4_2 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=512,
+ num_groups=256,
+ stride=(2, 1),
+ scale=scale)
+ self.block_list.append(conv4_2)
+
+ for _ in range(5):
+ conv5 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=512,
+ num_groups=512,
+ stride=1,
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=False)
+ self.block_list.append(conv5)
+
+ conv5_6 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=1024,
+ num_groups=512,
+ stride=(2, 1),
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=True)
+ self.block_list.append(conv5_6)
+
+ conv6 = DepthwiseSeparable(
+ num_channels=int(1024 * scale),
+ num_filters1=1024,
+ num_filters2=1024,
+ num_groups=1024,
+ stride=last_conv_stride,
+ dw_size=5,
+ padding=2,
+ use_se=True,
+ scale=scale)
+ self.block_list.append(conv6)
+
+ self.block_list = nn.Sequential(*self.block_list)
+ if last_pool_type == 'avg':
+ self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
+ else:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ self.out_channels = int(1024 * scale)
+
+ def forward(self, inputs):
+ y = self.conv1(inputs)
+ y = self.block_list(y)
+ y = self.pool(y)
+ return y
+
+def hardsigmoid(x):
+ return F.relu6(x + 3., inplace=True) / 6.
+
+class SEModule(nn.Module):
+ def __init__(self, channel, reduction=4):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True)
+ self.conv2 = nn.Conv2d(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True)
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = hardsigmoid(outputs)
+ x = torch.mul(inputs, outputs)
+
+ return x
diff --git a/iopaint/model/anytext/ocr_recog/RecSVTR.py b/iopaint/model/anytext/ocr_recog/RecSVTR.py
new file mode 100644
index 0000000..484b3df
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/RecSVTR.py
@@ -0,0 +1,591 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.nn.init import trunc_normal_, zeros_, ones_
+from torch.nn import functional
+
+
+def drop_path(x, drop_prob=0., training=False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = torch.tensor(1 - drop_prob)
+ shape = (x.size()[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
+ random_tensor = torch.floor(random_tensor) # binarize
+ output = x.divide(keep_prob) * random_tensor
+ return output
+
+
+class Swish(nn.Module):
+ def __int__(self):
+ super(Swish, self).__int__()
+
+ def forward(self,x):
+ return x*torch.sigmoid(x)
+
+
+class ConvBNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias_attr=False,
+ groups=1,
+ act=nn.GELU):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+ bias=bias_attr)
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, input):
+ return input
+
+
+class Mlp(nn.Module):
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ if isinstance(act_layer, str):
+ self.act = Swish()
+ else:
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvMixer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ HW=(8, 25),
+ local_k=(3, 3), ):
+ super().__init__()
+ self.HW = HW
+ self.dim = dim
+ self.local_mixer = nn.Conv2d(
+ dim,
+ dim,
+ local_k,
+ 1, (local_k[0] // 2, local_k[1] // 2),
+ groups=num_heads,
+ # weight_attr=ParamAttr(initializer=KaimingNormal())
+ )
+
+ def forward(self, x):
+ h = self.HW[0]
+ w = self.HW[1]
+ x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
+ x = self.local_mixer(x)
+ x = x.flatten(2).transpose([0, 2, 1])
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self,
+ dim,
+ num_heads=8,
+ mixer='Global',
+ HW=(8, 25),
+ local_k=(7, 11),
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.HW = HW
+ if HW is not None:
+ H = HW[0]
+ W = HW[1]
+ self.N = H * W
+ self.C = dim
+ if mixer == 'Local' and HW is not None:
+ hk = local_k[0]
+ wk = local_k[1]
+ mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
+ for h in range(0, H):
+ for w in range(0, W):
+ mask[h * W + w, h:h + hk, w:w + wk] = 0.
+ mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
+ 2].flatten(1)
+ mask_inf = torch.full([H * W, H * W],fill_value=float('-inf'))
+ mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
+ self.mask = mask[None,None,:]
+ # self.mask = mask.unsqueeze([0, 1])
+ self.mixer = mixer
+
+ def forward(self, x):
+ if self.HW is not None:
+ N = self.N
+ C = self.C
+ else:
+ _, N, C = x.shape
+ qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4))
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+ attn = (q.matmul(k.permute((0, 1, 3, 2))))
+ if self.mixer == 'Local':
+ attn += self.mask
+ attn = functional.softmax(attn, dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(self,
+ dim,
+ num_heads,
+ mixer='Global',
+ local_mixer=(7, 11),
+ HW=(8, 25),
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-6,
+ prenorm=True):
+ super().__init__()
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, eps=epsilon)
+ else:
+ self.norm1 = norm_layer(dim)
+ if mixer == 'Global' or mixer == 'Local':
+
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ mixer=mixer,
+ HW=HW,
+ local_k=local_mixer,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+ elif mixer == 'Conv':
+ self.mixer = ConvMixer(
+ dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
+ else:
+ raise TypeError("The mixer must be one of [Global, Local, Conv]")
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, eps=epsilon)
+ else:
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+ self.prenorm = prenorm
+
+ def forward(self, x):
+ if self.prenorm:
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.mixer(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self,
+ img_size=(32, 100),
+ in_channels=3,
+ embed_dim=768,
+ sub_num=2):
+ super().__init__()
+ num_patches = (img_size[1] // (2 ** sub_num)) * \
+ (img_size[0] // (2 ** sub_num))
+ self.img_size = img_size
+ self.num_patches = num_patches
+ self.embed_dim = embed_dim
+ self.norm = None
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False))
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False))
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).permute(0, 2, 1)
+ return x
+
+
+class SubSample(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ types='Pool',
+ stride=(2, 1),
+ sub_norm='nn.LayerNorm',
+ act=None):
+ super().__init__()
+ self.types = types
+ if types == 'Pool':
+ self.avgpool = nn.AvgPool2d(
+ kernel_size=(3, 5), stride=stride, padding=(1, 2))
+ self.maxpool = nn.MaxPool2d(
+ kernel_size=(3, 5), stride=stride, padding=(1, 2))
+ self.proj = nn.Linear(in_channels, out_channels)
+ else:
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ # weight_attr=ParamAttr(initializer=KaimingNormal())
+ )
+ self.norm = eval(sub_norm)(out_channels)
+ if act is not None:
+ self.act = act()
+ else:
+ self.act = None
+
+ def forward(self, x):
+
+ if self.types == 'Pool':
+ x1 = self.avgpool(x)
+ x2 = self.maxpool(x)
+ x = (x1 + x2) * 0.5
+ out = self.proj(x.flatten(2).permute((0, 2, 1)))
+ else:
+ x = self.conv(x)
+ out = x.flatten(2).permute((0, 2, 1))
+ out = self.norm(out)
+ if self.act is not None:
+ out = self.act(out)
+
+ return out
+
+
+class SVTRNet(nn.Module):
+ def __init__(
+ self,
+ img_size=[48, 100],
+ in_channels=3,
+ embed_dim=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer=['Local'] * 6 + ['Global'] *
+ 6, # Local atten, Global atten, Conv
+ local_mixer=[[7, 11], [7, 11], [7, 11]],
+ patch_merging='Conv', # Conv, Pool, None
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ last_drop=0.1,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer='nn.LayerNorm',
+ sub_norm='nn.LayerNorm',
+ epsilon=1e-6,
+ out_channels=192,
+ out_char_num=25,
+ block_unit='Block',
+ act='nn.GELU',
+ last_stage=True,
+ sub_num=2,
+ prenorm=True,
+ use_lenhead=False,
+ **kwargs):
+ super().__init__()
+ self.img_size = img_size
+ self.embed_dim = embed_dim
+ self.out_channels = out_channels
+ self.prenorm = prenorm
+ patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim[0],
+ sub_num=sub_num)
+ num_patches = self.patch_embed.num_patches
+ self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
+ # self.pos_embed = self.create_parameter(
+ # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
+
+ # self.add_parameter("pos_embed", self.pos_embed)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ Block_unit = eval(block_unit)
+
+ dpr = np.linspace(0, drop_path_rate, sum(depth))
+ self.blocks1 = nn.ModuleList(
+ [
+ Block_unit(
+ dim=embed_dim[0],
+ num_heads=num_heads[0],
+ mixer=mixer[0:depth[0]][i],
+ HW=self.HW,
+ local_mixer=local_mixer[0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[0:depth[0]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm) for i in range(depth[0])
+ ]
+ )
+ if patch_merging is not None:
+ self.sub_sample1 = SubSample(
+ embed_dim[0],
+ embed_dim[1],
+ sub_norm=sub_norm,
+ stride=[2, 1],
+ types=patch_merging)
+ HW = [self.HW[0] // 2, self.HW[1]]
+ else:
+ HW = self.HW
+ self.patch_merging = patch_merging
+ self.blocks2 = nn.ModuleList([
+ Block_unit(
+ dim=embed_dim[1],
+ num_heads=num_heads[1],
+ mixer=mixer[depth[0]:depth[0] + depth[1]][i],
+ HW=HW,
+ local_mixer=local_mixer[1],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm) for i in range(depth[1])
+ ])
+ if patch_merging is not None:
+ self.sub_sample2 = SubSample(
+ embed_dim[1],
+ embed_dim[2],
+ sub_norm=sub_norm,
+ stride=[2, 1],
+ types=patch_merging)
+ HW = [self.HW[0] // 4, self.HW[1]]
+ else:
+ HW = self.HW
+ self.blocks3 = nn.ModuleList([
+ Block_unit(
+ dim=embed_dim[2],
+ num_heads=num_heads[2],
+ mixer=mixer[depth[0] + depth[1]:][i],
+ HW=HW,
+ local_mixer=local_mixer[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] + depth[1]:][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm) for i in range(depth[2])
+ ])
+ self.last_stage = last_stage
+ if last_stage:
+ self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
+ self.last_conv = nn.Conv2d(
+ in_channels=embed_dim[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop)
+ if not prenorm:
+ self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
+ self.use_lenhead = use_lenhead
+ if use_lenhead:
+ self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
+ self.hardswish_len = nn.Hardswish()
+ self.dropout_len = nn.Dropout(
+ p=last_drop)
+
+ trunc_normal_(self.pos_embed,std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight,std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks1:
+ x = blk(x)
+ if self.patch_merging is not None:
+ x = self.sub_sample1(
+ x.permute([0, 2, 1]).reshape(
+ [-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
+ for blk in self.blocks2:
+ x = blk(x)
+ if self.patch_merging is not None:
+ x = self.sub_sample2(
+ x.permute([0, 2, 1]).reshape(
+ [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
+ for blk in self.blocks3:
+ x = blk(x)
+ if not self.prenorm:
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ if self.use_lenhead:
+ len_x = self.len_conv(x.mean(1))
+ len_x = self.dropout_len(self.hardswish_len(len_x))
+ if self.last_stage:
+ if self.patch_merging is not None:
+ h = self.HW[0] // 4
+ else:
+ h = self.HW[0]
+ x = self.avg_pool(
+ x.permute([0, 2, 1]).reshape(
+ [-1, self.embed_dim[2], h, self.HW[1]]))
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ if self.use_lenhead:
+ return x, len_x
+ return x
+
+
+if __name__=="__main__":
+ a = torch.rand(1,3,48,100)
+ svtr = SVTRNet()
+
+ out = svtr(a)
+ print(svtr)
+ print(out.size())
\ No newline at end of file
diff --git a/iopaint/model/anytext/ocr_recog/common.py b/iopaint/model/anytext/ocr_recog/common.py
new file mode 100644
index 0000000..a328bb0
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/common.py
@@ -0,0 +1,74 @@
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Hswish(nn.Module):
+ def __init__(self, inplace=True):
+ super(Hswish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x * F.relu6(x + 3., inplace=self.inplace) / 6.
+
+# out = max(0, min(1, slop*x+offset))
+# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
+class Hsigmoid(nn.Module):
+ def __init__(self, inplace=True):
+ super(Hsigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ # torch: F.relu6(x + 3., inplace=self.inplace) / 6.
+ # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
+ return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
+
+class GELU(nn.Module):
+ def __init__(self, inplace=True):
+ super(GELU, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return torch.nn.functional.gelu(x)
+
+
+class Swish(nn.Module):
+ def __init__(self, inplace=True):
+ super(Swish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ if self.inplace:
+ x.mul_(torch.sigmoid(x))
+ return x
+ else:
+ return x*torch.sigmoid(x)
+
+
+class Activation(nn.Module):
+ def __init__(self, act_type, inplace=True):
+ super(Activation, self).__init__()
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ self.act = nn.ReLU(inplace=inplace)
+ elif act_type == 'relu6':
+ self.act = nn.ReLU6(inplace=inplace)
+ elif act_type == 'sigmoid':
+ raise NotImplementedError
+ elif act_type == 'hard_sigmoid':
+ self.act = Hsigmoid(inplace)
+ elif act_type == 'hard_swish':
+ self.act = Hswish(inplace=inplace)
+ elif act_type == 'leakyrelu':
+ self.act = nn.LeakyReLU(inplace=inplace)
+ elif act_type == 'gelu':
+ self.act = GELU(inplace=inplace)
+ elif act_type == 'swish':
+ self.act = Swish(inplace=inplace)
+ else:
+ raise NotImplementedError
+
+ def forward(self, inputs):
+ return self.act(inputs)
\ No newline at end of file
diff --git a/iopaint/model/anytext/ocr_recog/en_dict.txt b/iopaint/model/anytext/ocr_recog/en_dict.txt
new file mode 100644
index 0000000..7677d31
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/en_dict.txt
@@ -0,0 +1,95 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+
diff --git a/iopaint/model/anytext/ocr_recog/ppocr_keys_v1.txt b/iopaint/model/anytext/ocr_recog/ppocr_keys_v1.txt
new file mode 100644
index 0000000..84b885d
--- /dev/null
+++ b/iopaint/model/anytext/ocr_recog/ppocr_keys_v1.txt
@@ -0,0 +1,6623 @@
+'
+疗
+绚
+诚
+娇
+溜
+题
+贿
+者
+廖
+更
+纳
+加
+奉
+公
+一
+就
+汴
+计
+与
+路
+房
+原
+妇
+2
+0
+8
+-
+7
+其
+>
+:
+]
+,
+,
+骑
+刈
+全
+消
+昏
+傈
+安
+久
+钟
+嗅
+不
+影
+处
+驽
+蜿
+资
+关
+椤
+地
+瘸
+专
+问
+忖
+票
+嫉
+炎
+韵
+要
+月
+田
+节
+陂
+鄙
+捌
+备
+拳
+伺
+眼
+网
+盎
+大
+傍
+心
+东
+愉
+汇
+蹿
+科
+每
+业
+里
+航
+晏
+字
+平
+录
+先
+1
+3
+彤
+鲶
+产
+稍
+督
+腴
+有
+象
+岳
+注
+绍
+在
+泺
+文
+定
+核
+名
+水
+过
+理
+让
+偷
+率
+等
+这
+发
+”
+为
+含
+肥
+酉
+相
+鄱
+七
+编
+猥
+锛
+日
+镀
+蒂
+掰
+倒
+辆
+栾
+栗
+综
+涩
+州
+雌
+滑
+馀
+了
+机
+块
+司
+宰
+甙
+兴
+矽
+抚
+保
+用
+沧
+秩
+如
+收
+息
+滥
+页
+疑
+埠
+!
+!
+姥
+异
+橹
+钇
+向
+下
+跄
+的
+椴
+沫
+国
+绥
+獠
+报
+开
+民
+蜇
+何
+分
+凇
+长
+讥
+藏
+掏
+施
+羽
+中
+讲
+派
+嘟
+人
+提
+浼
+间
+世
+而
+古
+多
+倪
+唇
+饯
+控
+庚
+首
+赛
+蜓
+味
+断
+制
+觉
+技
+替
+艰
+溢
+潮
+夕
+钺
+外
+摘
+枋
+动
+双
+单
+啮
+户
+枇
+确
+锦
+曜
+杜
+或
+能
+效
+霜
+盒
+然
+侗
+电
+晁
+放
+步
+鹃
+新
+杖
+蜂
+吒
+濂
+瞬
+评
+总
+隍
+对
+独
+合
+也
+是
+府
+青
+天
+诲
+墙
+组
+滴
+级
+邀
+帘
+示
+已
+时
+骸
+仄
+泅
+和
+遨
+店
+雇
+疫
+持
+巍
+踮
+境
+只
+亨
+目
+鉴
+崤
+闲
+体
+泄
+杂
+作
+般
+轰
+化
+解
+迂
+诿
+蛭
+璀
+腾
+告
+版
+服
+省
+师
+小
+规
+程
+线
+海
+办
+引
+二
+桧
+牌
+砺
+洄
+裴
+修
+图
+痫
+胡
+许
+犊
+事
+郛
+基
+柴
+呼
+食
+研
+奶
+律
+蛋
+因
+葆
+察
+戏
+褒
+戒
+再
+李
+骁
+工
+貂
+油
+鹅
+章
+啄
+休
+场
+给
+睡
+纷
+豆
+器
+捎
+说
+敏
+学
+会
+浒
+设
+诊
+格
+廓
+查
+来
+霓
+室
+溆
+¢
+诡
+寥
+焕
+舜
+柒
+狐
+回
+戟
+砾
+厄
+实
+翩
+尿
+五
+入
+径
+惭
+喹
+股
+宇
+篝
+|
+;
+美
+期
+云
+九
+祺
+扮
+靠
+锝
+槌
+系
+企
+酰
+阊
+暂
+蚕
+忻
+豁
+本
+羹
+执
+条
+钦
+H
+獒
+限
+进
+季
+楦
+于
+芘
+玖
+铋
+茯
+未
+答
+粘
+括
+样
+精
+欠
+矢
+甥
+帷
+嵩
+扣
+令
+仔
+风
+皈
+行
+支
+部
+蓉
+刮
+站
+蜡
+救
+钊
+汗
+松
+嫌
+成
+可
+.
+鹤
+院
+从
+交
+政
+怕
+活
+调
+球
+局
+验
+髌
+第
+韫
+谗
+串
+到
+圆
+年
+米
+/
+*
+友
+忿
+检
+区
+看
+自
+敢
+刃
+个
+兹
+弄
+流
+留
+同
+没
+齿
+星
+聆
+轼
+湖
+什
+三
+建
+蛔
+儿
+椋
+汕
+震
+颧
+鲤
+跟
+力
+情
+璺
+铨
+陪
+务
+指
+族
+训
+滦
+鄣
+濮
+扒
+商
+箱
+十
+召
+慷
+辗
+所
+莞
+管
+护
+臭
+横
+硒
+嗓
+接
+侦
+六
+露
+党
+馋
+驾
+剖
+高
+侬
+妪
+幂
+猗
+绺
+骐
+央
+酐
+孝
+筝
+课
+徇
+缰
+门
+男
+西
+项
+句
+谙
+瞒
+秃
+篇
+教
+碲
+罚
+声
+呐
+景
+前
+富
+嘴
+鳌
+稀
+免
+朋
+啬
+睐
+去
+赈
+鱼
+住
+肩
+愕
+速
+旁
+波
+厅
+健
+茼
+厥
+鲟
+谅
+投
+攸
+炔
+数
+方
+击
+呋
+谈
+绩
+别
+愫
+僚
+躬
+鹧
+胪
+炳
+招
+喇
+膨
+泵
+蹦
+毛
+结
+5
+4
+谱
+识
+陕
+粽
+婚
+拟
+构
+且
+搜
+任
+潘
+比
+郢
+妨
+醪
+陀
+桔
+碘
+扎
+选
+哈
+骷
+楷
+亿
+明
+缆
+脯
+监
+睫
+逻
+婵
+共
+赴
+淝
+凡
+惦
+及
+达
+揖
+谩
+澹
+减
+焰
+蛹
+番
+祁
+柏
+员
+禄
+怡
+峤
+龙
+白
+叽
+生
+闯
+起
+细
+装
+谕
+竟
+聚
+钙
+上
+导
+渊
+按
+艾
+辘
+挡
+耒
+盹
+饪
+臀
+记
+邮
+蕙
+受
+各
+医
+搂
+普
+滇
+朗
+茸
+带
+翻
+酚
+(
+光
+堤
+墟
+蔷
+万
+幻
+〓
+瑙
+辈
+昧
+盏
+亘
+蛀
+吉
+铰
+请
+子
+假
+闻
+税
+井
+诩
+哨
+嫂
+好
+面
+琐
+校
+馊
+鬣
+缂
+营
+访
+炖
+占
+农
+缀
+否
+经
+钚
+棵
+趟
+张
+亟
+吏
+茶
+谨
+捻
+论
+迸
+堂
+玉
+信
+吧
+瞠
+乡
+姬
+寺
+咬
+溏
+苄
+皿
+意
+赉
+宝
+尔
+钰
+艺
+特
+唳
+踉
+都
+荣
+倚
+登
+荐
+丧
+奇
+涵
+批
+炭
+近
+符
+傩
+感
+道
+着
+菊
+虹
+仲
+众
+懈
+濯
+颞
+眺
+南
+释
+北
+缝
+标
+既
+茗
+整
+撼
+迤
+贲
+挎
+耱
+拒
+某
+妍
+卫
+哇
+英
+矶
+藩
+治
+他
+元
+领
+膜
+遮
+穗
+蛾
+飞
+荒
+棺
+劫
+么
+市
+火
+温
+拈
+棚
+洼
+转
+果
+奕
+卸
+迪
+伸
+泳
+斗
+邡
+侄
+涨
+屯
+萋
+胭
+氡
+崮
+枞
+惧
+冒
+彩
+斜
+手
+豚
+随
+旭
+淑
+妞
+形
+菌
+吲
+沱
+争
+驯
+歹
+挟
+兆
+柱
+传
+至
+包
+内
+响
+临
+红
+功
+弩
+衡
+寂
+禁
+老
+棍
+耆
+渍
+织
+害
+氵
+渑
+布
+载
+靥
+嗬
+虽
+苹
+咨
+娄
+库
+雉
+榜
+帜
+嘲
+套
+瑚
+亲
+簸
+欧
+边
+6
+腿
+旮
+抛
+吹
+瞳
+得
+镓
+梗
+厨
+继
+漾
+愣
+憨
+士
+策
+窑
+抑
+躯
+襟
+脏
+参
+贸
+言
+干
+绸
+鳄
+穷
+藜
+音
+折
+详
+)
+举
+悍
+甸
+癌
+黎
+谴
+死
+罩
+迁
+寒
+驷
+袖
+媒
+蒋
+掘
+模
+纠
+恣
+观
+祖
+蛆
+碍
+位
+稿
+主
+澧
+跌
+筏
+京
+锏
+帝
+贴
+证
+糠
+才
+黄
+鲸
+略
+炯
+饱
+四
+出
+园
+犀
+牧
+容
+汉
+杆
+浈
+汰
+瑷
+造
+虫
+瘩
+怪
+驴
+济
+应
+花
+沣
+谔
+夙
+旅
+价
+矿
+以
+考
+s
+u
+呦
+晒
+巡
+茅
+准
+肟
+瓴
+詹
+仟
+褂
+译
+桌
+混
+宁
+怦
+郑
+抿
+些
+余
+鄂
+饴
+攒
+珑
+群
+阖
+岔
+琨
+藓
+预
+环
+洮
+岌
+宀
+杲
+瀵
+最
+常
+囡
+周
+踊
+女
+鼓
+袭
+喉
+简
+范
+薯
+遐
+疏
+粱
+黜
+禧
+法
+箔
+斤
+遥
+汝
+奥
+直
+贞
+撑
+置
+绱
+集
+她
+馅
+逗
+钧
+橱
+魉
+[
+恙
+躁
+唤
+9
+旺
+膘
+待
+脾
+惫
+购
+吗
+依
+盲
+度
+瘿
+蠖
+俾
+之
+镗
+拇
+鲵
+厝
+簧
+续
+款
+展
+啃
+表
+剔
+品
+钻
+腭
+损
+清
+锶
+统
+涌
+寸
+滨
+贪
+链
+吠
+冈
+伎
+迥
+咏
+吁
+览
+防
+迅
+失
+汾
+阔
+逵
+绀
+蔑
+列
+川
+凭
+努
+熨
+揪
+利
+俱
+绉
+抢
+鸨
+我
+即
+责
+膦
+易
+毓
+鹊
+刹
+玷
+岿
+空
+嘞
+绊
+排
+术
+估
+锷
+违
+们
+苟
+铜
+播
+肘
+件
+烫
+审
+鲂
+广
+像
+铌
+惰
+铟
+巳
+胍
+鲍
+康
+憧
+色
+恢
+想
+拷
+尤
+疳
+知
+S
+Y
+F
+D
+A
+峄
+裕
+帮
+握
+搔
+氐
+氘
+难
+墒
+沮
+雨
+叁
+缥
+悴
+藐
+湫
+娟
+苑
+稠
+颛
+簇
+后
+阕
+闭
+蕤
+缚
+怎
+佞
+码
+嘤
+蔡
+痊
+舱
+螯
+帕
+赫
+昵
+升
+烬
+岫
+、
+疵
+蜻
+髁
+蕨
+隶
+烛
+械
+丑
+盂
+梁
+强
+鲛
+由
+拘
+揉
+劭
+龟
+撤
+钩
+呕
+孛
+费
+妻
+漂
+求
+阑
+崖
+秤
+甘
+通
+深
+补
+赃
+坎
+床
+啪
+承
+吼
+量
+暇
+钼
+烨
+阂
+擎
+脱
+逮
+称
+P
+神
+属
+矗
+华
+届
+狍
+葑
+汹
+育
+患
+窒
+蛰
+佼
+静
+槎
+运
+鳗
+庆
+逝
+曼
+疱
+克
+代
+官
+此
+麸
+耧
+蚌
+晟
+例
+础
+榛
+副
+测
+唰
+缢
+迹
+灬
+霁
+身
+岁
+赭
+扛
+又
+菡
+乜
+雾
+板
+读
+陷
+徉
+贯
+郁
+虑
+变
+钓
+菜
+圾
+现
+琢
+式
+乐
+维
+渔
+浜
+左
+吾
+脑
+钡
+警
+T
+啵
+拴
+偌
+漱
+湿
+硕
+止
+骼
+魄
+积
+燥
+联
+踢
+玛
+则
+窿
+见
+振
+畿
+送
+班
+钽
+您
+赵
+刨
+印
+讨
+踝
+籍
+谡
+舌
+崧
+汽
+蔽
+沪
+酥
+绒
+怖
+财
+帖
+肱
+私
+莎
+勋
+羔
+霸
+励
+哼
+帐
+将
+帅
+渠
+纪
+婴
+娩
+岭
+厘
+滕
+吻
+伤
+坝
+冠
+戊
+隆
+瘁
+介
+涧
+物
+黍
+并
+姗
+奢
+蹑
+掣
+垸
+锴
+命
+箍
+捉
+病
+辖
+琰
+眭
+迩
+艘
+绌
+繁
+寅
+若
+毋
+思
+诉
+类
+诈
+燮
+轲
+酮
+狂
+重
+反
+职
+筱
+县
+委
+磕
+绣
+奖
+晋
+濉
+志
+徽
+肠
+呈
+獐
+坻
+口
+片
+碰
+几
+村
+柿
+劳
+料
+获
+亩
+惕
+晕
+厌
+号
+罢
+池
+正
+鏖
+煨
+家
+棕
+复
+尝
+懋
+蜥
+锅
+岛
+扰
+队
+坠
+瘾
+钬
+@
+卧
+疣
+镇
+譬
+冰
+彷
+频
+黯
+据
+垄
+采
+八
+缪
+瘫
+型
+熹
+砰
+楠
+襁
+箐
+但
+嘶
+绳
+啤
+拍
+盥
+穆
+傲
+洗
+盯
+塘
+怔
+筛
+丿
+台
+恒
+喂
+葛
+永
+¥
+烟
+酒
+桦
+书
+砂
+蚝
+缉
+态
+瀚
+袄
+圳
+轻
+蛛
+超
+榧
+遛
+姒
+奘
+铮
+右
+荽
+望
+偻
+卡
+丶
+氰
+附
+做
+革
+索
+戚
+坨
+桷
+唁
+垅
+榻
+岐
+偎
+坛
+莨
+山
+殊
+微
+骇
+陈
+爨
+推
+嗝
+驹
+澡
+藁
+呤
+卤
+嘻
+糅
+逛
+侵
+郓
+酌
+德
+摇
+※
+鬃
+被
+慨
+殡
+羸
+昌
+泡
+戛
+鞋
+河
+宪
+沿
+玲
+鲨
+翅
+哽
+源
+铅
+语
+照
+邯
+址
+荃
+佬
+顺
+鸳
+町
+霭
+睾
+瓢
+夸
+椁
+晓
+酿
+痈
+咔
+侏
+券
+噎
+湍
+签
+嚷
+离
+午
+尚
+社
+锤
+背
+孟
+使
+浪
+缦
+潍
+鞅
+军
+姹
+驶
+笑
+鳟
+鲁
+》
+孽
+钜
+绿
+洱
+礴
+焯
+椰
+颖
+囔
+乌
+孔
+巴
+互
+性
+椽
+哞
+聘
+昨
+早
+暮
+胶
+炀
+隧
+低
+彗
+昝
+铁
+呓
+氽
+藉
+喔
+癖
+瑗
+姨
+权
+胱
+韦
+堑
+蜜
+酋
+楝
+砝
+毁
+靓
+歙
+锲
+究
+屋
+喳
+骨
+辨
+碑
+武
+鸠
+宫
+辜
+烊
+适
+坡
+殃
+培
+佩
+供
+走
+蜈
+迟
+翼
+况
+姣
+凛
+浔
+吃
+飘
+债
+犟
+金
+促
+苛
+崇
+坂
+莳
+畔
+绂
+兵
+蠕
+斋
+根
+砍
+亢
+欢
+恬
+崔
+剁
+餐
+榫
+快
+扶
+‖
+濒
+缠
+鳜
+当
+彭
+驭
+浦
+篮
+昀
+锆
+秸
+钳
+弋
+娣
+瞑
+夷
+龛
+苫
+拱
+致
+%
+嵊
+障
+隐
+弑
+初
+娓
+抉
+汩
+累
+蓖
+"
+唬
+助
+苓
+昙
+押
+毙
+破
+城
+郧
+逢
+嚏
+獭
+瞻
+溱
+婿
+赊
+跨
+恼
+璧
+萃
+姻
+貉
+灵
+炉
+密
+氛
+陶
+砸
+谬
+衔
+点
+琛
+沛
+枳
+层
+岱
+诺
+脍
+榈
+埂
+征
+冷
+裁
+打
+蹴
+素
+瘘
+逞
+蛐
+聊
+激
+腱
+萘
+踵
+飒
+蓟
+吆
+取
+咙
+簋
+涓
+矩
+曝
+挺
+揣
+座
+你
+史
+舵
+焱
+尘
+苏
+笈
+脚
+溉
+榨
+诵
+樊
+邓
+焊
+义
+庶
+儋
+蟋
+蒲
+赦
+呷
+杞
+诠
+豪
+还
+试
+颓
+茉
+太
+除
+紫
+逃
+痴
+草
+充
+鳕
+珉
+祗
+墨
+渭
+烩
+蘸
+慕
+璇
+镶
+穴
+嵘
+恶
+骂
+险
+绋
+幕
+碉
+肺
+戳
+刘
+潞
+秣
+纾
+潜
+銮
+洛
+须
+罘
+销
+瘪
+汞
+兮
+屉
+r
+林
+厕
+质
+探
+划
+狸
+殚
+善
+煊
+烹
+〒
+锈
+逯
+宸
+辍
+泱
+柚
+袍
+远
+蹋
+嶙
+绝
+峥
+娥
+缍
+雀
+徵
+认
+镱
+谷
+=
+贩
+勉
+撩
+鄯
+斐
+洋
+非
+祚
+泾
+诒
+饿
+撬
+威
+晷
+搭
+芍
+锥
+笺
+蓦
+候
+琊
+档
+礁
+沼
+卵
+荠
+忑
+朝
+凹
+瑞
+头
+仪
+弧
+孵
+畏
+铆
+突
+衲
+车
+浩
+气
+茂
+悖
+厢
+枕
+酝
+戴
+湾
+邹
+飚
+攘
+锂
+写
+宵
+翁
+岷
+无
+喜
+丈
+挑
+嗟
+绛
+殉
+议
+槽
+具
+醇
+淞
+笃
+郴
+阅
+饼
+底
+壕
+砚
+弈
+询
+缕
+庹
+翟
+零
+筷
+暨
+舟
+闺
+甯
+撞
+麂
+茌
+蔼
+很
+珲
+捕
+棠
+角
+阉
+媛
+娲
+诽
+剿
+尉
+爵
+睬
+韩
+诰
+匣
+危
+糍
+镯
+立
+浏
+阳
+少
+盆
+舔
+擘
+匪
+申
+尬
+铣
+旯
+抖
+赘
+瓯
+居
+ˇ
+哮
+游
+锭
+茏
+歌
+坏
+甚
+秒
+舞
+沙
+仗
+劲
+潺
+阿
+燧
+郭
+嗖
+霏
+忠
+材
+奂
+耐
+跺
+砀
+输
+岖
+媳
+氟
+极
+摆
+灿
+今
+扔
+腻
+枝
+奎
+药
+熄
+吨
+话
+q
+额
+慑
+嘌
+协
+喀
+壳
+埭
+视
+著
+於
+愧
+陲
+翌
+峁
+颅
+佛
+腹
+聋
+侯
+咎
+叟
+秀
+颇
+存
+较
+罪
+哄
+岗
+扫
+栏
+钾
+羌
+己
+璨
+枭
+霉
+煌
+涸
+衿
+键
+镝
+益
+岢
+奏
+连
+夯
+睿
+冥
+均
+糖
+狞
+蹊
+稻
+爸
+刿
+胥
+煜
+丽
+肿
+璃
+掸
+跚
+灾
+垂
+樾
+濑
+乎
+莲
+窄
+犹
+撮
+战
+馄
+软
+络
+显
+鸢
+胸
+宾
+妲
+恕
+埔
+蝌
+份
+遇
+巧
+瞟
+粒
+恰
+剥
+桡
+博
+讯
+凯
+堇
+阶
+滤
+卖
+斌
+骚
+彬
+兑
+磺
+樱
+舷
+两
+娱
+福
+仃
+差
+找
+桁
+÷
+净
+把
+阴
+污
+戬
+雷
+碓
+蕲
+楚
+罡
+焖
+抽
+妫
+咒
+仑
+闱
+尽
+邑
+菁
+爱
+贷
+沥
+鞑
+牡
+嗉
+崴
+骤
+塌
+嗦
+订
+拮
+滓
+捡
+锻
+次
+坪
+杩
+臃
+箬
+融
+珂
+鹗
+宗
+枚
+降
+鸬
+妯
+阄
+堰
+盐
+毅
+必
+杨
+崃
+俺
+甬
+状
+莘
+货
+耸
+菱
+腼
+铸
+唏
+痤
+孚
+澳
+懒
+溅
+翘
+疙
+杷
+淼
+缙
+骰
+喊
+悉
+砻
+坷
+艇
+赁
+界
+谤
+纣
+宴
+晃
+茹
+归
+饭
+梢
+铡
+街
+抄
+肼
+鬟
+苯
+颂
+撷
+戈
+炒
+咆
+茭
+瘙
+负
+仰
+客
+琉
+铢
+封
+卑
+珥
+椿
+镧
+窨
+鬲
+寿
+御
+袤
+铃
+萎
+砖
+餮
+脒
+裳
+肪
+孕
+嫣
+馗
+嵇
+恳
+氯
+江
+石
+褶
+冢
+祸
+阻
+狈
+羞
+银
+靳
+透
+咳
+叼
+敷
+芷
+啥
+它
+瓤
+兰
+痘
+懊
+逑
+肌
+往
+捺
+坊
+甩
+呻
+〃
+沦
+忘
+膻
+祟
+菅
+剧
+崆
+智
+坯
+臧
+霍
+墅
+攻
+眯
+倘
+拢
+骠
+铐
+庭
+岙
+瓠
+′
+缺
+泥
+迢
+捶
+?
+?
+郏
+喙
+掷
+沌
+纯
+秘
+种
+听
+绘
+固
+螨
+团
+香
+盗
+妒
+埚
+蓝
+拖
+旱
+荞
+铀
+血
+遏
+汲
+辰
+叩
+拽
+幅
+硬
+惶
+桀
+漠
+措
+泼
+唑
+齐
+肾
+念
+酱
+虚
+屁
+耶
+旗
+砦
+闵
+婉
+馆
+拭
+绅
+韧
+忏
+窝
+醋
+葺
+顾
+辞
+倜
+堆
+辋
+逆
+玟
+贱
+疾
+董
+惘
+倌
+锕
+淘
+嘀
+莽
+俭
+笏
+绑
+鲷
+杈
+择
+蟀
+粥
+嗯
+驰
+逾
+案
+谪
+褓
+胫
+哩
+昕
+颚
+鲢
+绠
+躺
+鹄
+崂
+儒
+俨
+丝
+尕
+泌
+啊
+萸
+彰
+幺
+吟
+骄
+苣
+弦
+脊
+瑰
+〈
+诛
+镁
+析
+闪
+剪
+侧
+哟
+框
+螃
+守
+嬗
+燕
+狭
+铈
+缮
+概
+迳
+痧
+鲲
+俯
+售
+笼
+痣
+扉
+挖
+满
+咋
+援
+邱
+扇
+歪
+便
+玑
+绦
+峡
+蛇
+叨
+〖
+泽
+胃
+斓
+喋
+怂
+坟
+猪
+该
+蚬
+炕
+弥
+赞
+棣
+晔
+娠
+挲
+狡
+创
+疖
+铕
+镭
+稷
+挫
+弭
+啾
+翔
+粉
+履
+苘
+哦
+楼
+秕
+铂
+土
+锣
+瘟
+挣
+栉
+习
+享
+桢
+袅
+磨
+桂
+谦
+延
+坚
+蔚
+噗
+署
+谟
+猬
+钎
+恐
+嬉
+雒
+倦
+衅
+亏
+璩
+睹
+刻
+殿
+王
+算
+雕
+麻
+丘
+柯
+骆
+丸
+塍
+谚
+添
+鲈
+垓
+桎
+蚯
+芥
+予
+飕
+镦
+谌
+窗
+醚
+菀
+亮
+搪
+莺
+蒿
+羁
+足
+J
+真
+轶
+悬
+衷
+靛
+翊
+掩
+哒
+炅
+掐
+冼
+妮
+l
+谐
+稚
+荆
+擒
+犯
+陵
+虏
+浓
+崽
+刍
+陌
+傻
+孜
+千
+靖
+演
+矜
+钕
+煽
+杰
+酗
+渗
+伞
+栋
+俗
+泫
+戍
+罕
+沾
+疽
+灏
+煦
+芬
+磴
+叱
+阱
+榉
+湃
+蜀
+叉
+醒
+彪
+租
+郡
+篷
+屎
+良
+垢
+隗
+弱
+陨
+峪
+砷
+掴
+颁
+胎
+雯
+绵
+贬
+沐
+撵
+隘
+篙
+暖
+曹
+陡
+栓
+填
+臼
+彦
+瓶
+琪
+潼
+哪
+鸡
+摩
+啦
+俟
+锋
+域
+耻
+蔫
+疯
+纹
+撇
+毒
+绶
+痛
+酯
+忍
+爪
+赳
+歆
+嘹
+辕
+烈
+册
+朴
+钱
+吮
+毯
+癜
+娃
+谀
+邵
+厮
+炽
+璞
+邃
+丐
+追
+词
+瓒
+忆
+轧
+芫
+谯
+喷
+弟
+半
+冕
+裙
+掖
+墉
+绮
+寝
+苔
+势
+顷
+褥
+切
+衮
+君
+佳
+嫒
+蚩
+霞
+佚
+洙
+逊
+镖
+暹
+唛
+&
+殒
+顶
+碗
+獗
+轭
+铺
+蛊
+废
+恹
+汨
+崩
+珍
+那
+杵
+曲
+纺
+夏
+薰
+傀
+闳
+淬
+姘
+舀
+拧
+卷
+楂
+恍
+讪
+厩
+寮
+篪
+赓
+乘
+灭
+盅
+鞣
+沟
+慎
+挂
+饺
+鼾
+杳
+树
+缨
+丛
+絮
+娌
+臻
+嗳
+篡
+侩
+述
+衰
+矛
+圈
+蚜
+匕
+筹
+匿
+濞
+晨
+叶
+骋
+郝
+挚
+蚴
+滞
+增
+侍
+描
+瓣
+吖
+嫦
+蟒
+匾
+圣
+赌
+毡
+癞
+恺
+百
+曳
+需
+篓
+肮
+庖
+帏
+卿
+驿
+遗
+蹬
+鬓
+骡
+歉
+芎
+胳
+屐
+禽
+烦
+晌
+寄
+媾
+狄
+翡
+苒
+船
+廉
+终
+痞
+殇
+々
+畦
+饶
+改
+拆
+悻
+萄
+£
+瓿
+乃
+訾
+桅
+匮
+溧
+拥
+纱
+铍
+骗
+蕃
+龋
+缬
+父
+佐
+疚
+栎
+醍
+掳
+蓄
+x
+惆
+颜
+鲆
+榆
+〔
+猎
+敌
+暴
+谥
+鲫
+贾
+罗
+玻
+缄
+扦
+芪
+癣
+落
+徒
+臾
+恿
+猩
+托
+邴
+肄
+牵
+春
+陛
+耀
+刊
+拓
+蓓
+邳
+堕
+寇
+枉
+淌
+啡
+湄
+兽
+酷
+萼
+碚
+濠
+萤
+夹
+旬
+戮
+梭
+琥
+椭
+昔
+勺
+蜊
+绐
+晚
+孺
+僵
+宣
+摄
+冽
+旨
+萌
+忙
+蚤
+眉
+噼
+蟑
+付
+契
+瓜
+悼
+颡
+壁
+曾
+窕
+颢
+澎
+仿
+俑
+浑
+嵌
+浣
+乍
+碌
+褪
+乱
+蔟
+隙
+玩
+剐
+葫
+箫
+纲
+围
+伐
+决
+伙
+漩
+瑟
+刑
+肓
+镳
+缓
+蹭
+氨
+皓
+典
+畲
+坍
+铑
+檐
+塑
+洞
+倬
+储
+胴
+淳
+戾
+吐
+灼
+惺
+妙
+毕
+珐
+缈
+虱
+盖
+羰
+鸿
+磅
+谓
+髅
+娴
+苴
+唷
+蚣
+霹
+抨
+贤
+唠
+犬
+誓
+逍
+庠
+逼
+麓
+籼
+釉
+呜
+碧
+秧
+氩
+摔
+霄
+穸
+纨
+辟
+妈
+映
+完
+牛
+缴
+嗷
+炊
+恩
+荔
+茆
+掉
+紊
+慌
+莓
+羟
+阙
+萁
+磐
+另
+蕹
+辱
+鳐
+湮
+吡
+吩
+唐
+睦
+垠
+舒
+圜
+冗
+瞿
+溺
+芾
+囱
+匠
+僳
+汐
+菩
+饬
+漓
+黑
+霰
+浸
+濡
+窥
+毂
+蒡
+兢
+驻
+鹉
+芮
+诙
+迫
+雳
+厂
+忐
+臆
+猴
+鸣
+蚪
+栈
+箕
+羡
+渐
+莆
+捍
+眈
+哓
+趴
+蹼
+埕
+嚣
+骛
+宏
+淄
+斑
+噜
+严
+瑛
+垃
+椎
+诱
+压
+庾
+绞
+焘
+廿
+抡
+迄
+棘
+夫
+纬
+锹
+眨
+瞌
+侠
+脐
+竞
+瀑
+孳
+骧
+遁
+姜
+颦
+荪
+滚
+萦
+伪
+逸
+粳
+爬
+锁
+矣
+役
+趣
+洒
+颔
+诏
+逐
+奸
+甭
+惠
+攀
+蹄
+泛
+尼
+拼
+阮
+鹰
+亚
+颈
+惑
+勒
+〉
+际
+肛
+爷
+刚
+钨
+丰
+养
+冶
+鲽
+辉
+蔻
+画
+覆
+皴
+妊
+麦
+返
+醉
+皂
+擀
+〗
+酶
+凑
+粹
+悟
+诀
+硖
+港
+卜
+z
+杀
+涕
+±
+舍
+铠
+抵
+弛
+段
+敝
+镐
+奠
+拂
+轴
+跛
+袱
+e
+t
+沉
+菇
+俎
+薪
+峦
+秭
+蟹
+历
+盟
+菠
+寡
+液
+肢
+喻
+染
+裱
+悱
+抱
+氙
+赤
+捅
+猛
+跑
+氮
+谣
+仁
+尺
+辊
+窍
+烙
+衍
+架
+擦
+倏
+璐
+瑁
+币
+楞
+胖
+夔
+趸
+邛
+惴
+饕
+虔
+蝎
+§
+哉
+贝
+宽
+辫
+炮
+扩
+饲
+籽
+魏
+菟
+锰
+伍
+猝
+末
+琳
+哚
+蛎
+邂
+呀
+姿
+鄞
+却
+歧
+仙
+恸
+椐
+森
+牒
+寤
+袒
+婆
+虢
+雅
+钉
+朵
+贼
+欲
+苞
+寰
+故
+龚
+坭
+嘘
+咫
+礼
+硷
+兀
+睢
+汶
+’
+铲
+烧
+绕
+诃
+浃
+钿
+哺
+柜
+讼
+颊
+璁
+腔
+洽
+咐
+脲
+簌
+筠
+镣
+玮
+鞠
+谁
+兼
+姆
+挥
+梯
+蝴
+谘
+漕
+刷
+躏
+宦
+弼
+b
+垌
+劈
+麟
+莉
+揭
+笙
+渎
+仕
+嗤
+仓
+配
+怏
+抬
+错
+泯
+镊
+孰
+猿
+邪
+仍
+秋
+鼬
+壹
+歇
+吵
+炼
+<
+尧
+射
+柬
+廷
+胧
+霾
+凳
+隋
+肚
+浮
+梦
+祥
+株
+堵
+退
+L
+鹫
+跎
+凶
+毽
+荟
+炫
+栩
+玳
+甜
+沂
+鹿
+顽
+伯
+爹
+赔
+蛴
+徐
+匡
+欣
+狰
+缸
+雹
+蟆
+疤
+默
+沤
+啜
+痂
+衣
+禅
+w
+i
+h
+辽
+葳
+黝
+钗
+停
+沽
+棒
+馨
+颌
+肉
+吴
+硫
+悯
+劾
+娈
+马
+啧
+吊
+悌
+镑
+峭
+帆
+瀣
+涉
+咸
+疸
+滋
+泣
+翦
+拙
+癸
+钥
+蜒
++
+尾
+庄
+凝
+泉
+婢
+渴
+谊
+乞
+陆
+锉
+糊
+鸦
+淮
+I
+B
+N
+晦
+弗
+乔
+庥
+葡
+尻
+席
+橡
+傣
+渣
+拿
+惩
+麋
+斛
+缃
+矮
+蛏
+岘
+鸽
+姐
+膏
+催
+奔
+镒
+喱
+蠡
+摧
+钯
+胤
+柠
+拐
+璋
+鸥
+卢
+荡
+倾
+^
+_
+珀
+逄
+萧
+塾
+掇
+贮
+笆
+聂
+圃
+冲
+嵬
+M
+滔
+笕
+值
+炙
+偶
+蜱
+搐
+梆
+汪
+蔬
+腑
+鸯
+蹇
+敞
+绯
+仨
+祯
+谆
+梧
+糗
+鑫
+啸
+豺
+囹
+猾
+巢
+柄
+瀛
+筑
+踌
+沭
+暗
+苁
+鱿
+蹉
+脂
+蘖
+牢
+热
+木
+吸
+溃
+宠
+序
+泞
+偿
+拜
+檩
+厚
+朐
+毗
+螳
+吞
+媚
+朽
+担
+蝗
+橘
+畴
+祈
+糟
+盱
+隼
+郜
+惜
+珠
+裨
+铵
+焙
+琚
+唯
+咚
+噪
+骊
+丫
+滢
+勤
+棉
+呸
+咣
+淀
+隔
+蕾
+窈
+饨
+挨
+煅
+短
+匙
+粕
+镜
+赣
+撕
+墩
+酬
+馁
+豌
+颐
+抗
+酣
+氓
+佑
+搁
+哭
+递
+耷
+涡
+桃
+贻
+碣
+截
+瘦
+昭
+镌
+蔓
+氚
+甲
+猕
+蕴
+蓬
+散
+拾
+纛
+狼
+猷
+铎
+埋
+旖
+矾
+讳
+囊
+糜
+迈
+粟
+蚂
+紧
+鲳
+瘢
+栽
+稼
+羊
+锄
+斟
+睁
+桥
+瓮
+蹙
+祉
+醺
+鼻
+昱
+剃
+跳
+篱
+跷
+蒜
+翎
+宅
+晖
+嗑
+壑
+峻
+癫
+屏
+狠
+陋
+袜
+途
+憎
+祀
+莹
+滟
+佶
+溥
+臣
+约
+盛
+峰
+磁
+慵
+婪
+拦
+莅
+朕
+鹦
+粲
+裤
+哎
+疡
+嫖
+琵
+窟
+堪
+谛
+嘉
+儡
+鳝
+斩
+郾
+驸
+酊
+妄
+胜
+贺
+徙
+傅
+噌
+钢
+栅
+庇
+恋
+匝
+巯
+邈
+尸
+锚
+粗
+佟
+蛟
+薹
+纵
+蚊
+郅
+绢
+锐
+苗
+俞
+篆
+淆
+膀
+鲜
+煎
+诶
+秽
+寻
+涮
+刺
+怀
+噶
+巨
+褰
+魅
+灶
+灌
+桉
+藕
+谜
+舸
+薄
+搀
+恽
+借
+牯
+痉
+渥
+愿
+亓
+耘
+杠
+柩
+锔
+蚶
+钣
+珈
+喘
+蹒
+幽
+赐
+稗
+晤
+莱
+泔
+扯
+肯
+菪
+裆
+腩
+豉
+疆
+骜
+腐
+倭
+珏
+唔
+粮
+亡
+润
+慰
+伽
+橄
+玄
+誉
+醐
+胆
+龊
+粼
+塬
+陇
+彼
+削
+嗣
+绾
+芽
+妗
+垭
+瘴
+爽
+薏
+寨
+龈
+泠
+弹
+赢
+漪
+猫
+嘧
+涂
+恤
+圭
+茧
+烽
+屑
+痕
+巾
+赖
+荸
+凰
+腮
+畈
+亵
+蹲
+偃
+苇
+澜
+艮
+换
+骺
+烘
+苕
+梓
+颉
+肇
+哗
+悄
+氤
+涠
+葬
+屠
+鹭
+植
+竺
+佯
+诣
+鲇
+瘀
+鲅
+邦
+移
+滁
+冯
+耕
+癔
+戌
+茬
+沁
+巩
+悠
+湘
+洪
+痹
+锟
+循
+谋
+腕
+鳃
+钠
+捞
+焉
+迎
+碱
+伫
+急
+榷
+奈
+邝
+卯
+辄
+皲
+卟
+醛
+畹
+忧
+稳
+雄
+昼
+缩
+阈
+睑
+扌
+耗
+曦
+涅
+捏
+瞧
+邕
+淖
+漉
+铝
+耦
+禹
+湛
+喽
+莼
+琅
+诸
+苎
+纂
+硅
+始
+嗨
+傥
+燃
+臂
+赅
+嘈
+呆
+贵
+屹
+壮
+肋
+亍
+蚀
+卅
+豹
+腆
+邬
+迭
+浊
+}
+童
+螂
+捐
+圩
+勐
+触
+寞
+汊
+壤
+荫
+膺
+渌
+芳
+懿
+遴
+螈
+泰
+蓼
+蛤
+茜
+舅
+枫
+朔
+膝
+眙
+避
+梅
+判
+鹜
+璜
+牍
+缅
+垫
+藻
+黔
+侥
+惚
+懂
+踩
+腰
+腈
+札
+丞
+唾
+慈
+顿
+摹
+荻
+琬
+~
+斧
+沈
+滂
+胁
+胀
+幄
+莜
+Z
+匀
+鄄
+掌
+绰
+茎
+焚
+赋
+萱
+谑
+汁
+铒
+瞎
+夺
+蜗
+野
+娆
+冀
+弯
+篁
+懵
+灞
+隽
+芡
+脘
+俐
+辩
+芯
+掺
+喏
+膈
+蝈
+觐
+悚
+踹
+蔗
+熠
+鼠
+呵
+抓
+橼
+峨
+畜
+缔
+禾
+崭
+弃
+熊
+摒
+凸
+拗
+穹
+蒙
+抒
+祛
+劝
+闫
+扳
+阵
+醌
+踪
+喵
+侣
+搬
+仅
+荧
+赎
+蝾
+琦
+买
+婧
+瞄
+寓
+皎
+冻
+赝
+箩
+莫
+瞰
+郊
+笫
+姝
+筒
+枪
+遣
+煸
+袋
+舆
+痱
+涛
+母
+〇
+启
+践
+耙
+绲
+盘
+遂
+昊
+搞
+槿
+诬
+纰
+泓
+惨
+檬
+亻
+越
+C
+o
+憩
+熵
+祷
+钒
+暧
+塔
+阗
+胰
+咄
+娶
+魔
+琶
+钞
+邻
+扬
+杉
+殴
+咽
+弓
+〆
+髻
+】
+吭
+揽
+霆
+拄
+殖
+脆
+彻
+岩
+芝
+勃
+辣
+剌
+钝
+嘎
+甄
+佘
+皖
+伦
+授
+徕
+憔
+挪
+皇
+庞
+稔
+芜
+踏
+溴
+兖
+卒
+擢
+饥
+鳞
+煲
+‰
+账
+颗
+叻
+斯
+捧
+鳍
+琮
+讹
+蛙
+纽
+谭
+酸
+兔
+莒
+睇
+伟
+觑
+羲
+嗜
+宜
+褐
+旎
+辛
+卦
+诘
+筋
+鎏
+溪
+挛
+熔
+阜
+晰
+鳅
+丢
+奚
+灸
+呱
+献
+陉
+黛
+鸪
+甾
+萨
+疮
+拯
+洲
+疹
+辑
+叙
+恻
+谒
+允
+柔
+烂
+氏
+逅
+漆
+拎
+惋
+扈
+湟
+纭
+啕
+掬
+擞
+哥
+忽
+涤
+鸵
+靡
+郗
+瓷
+扁
+廊
+怨
+雏
+钮
+敦
+E
+懦
+憋
+汀
+拚
+啉
+腌
+岸
+f
+痼
+瞅
+尊
+咀
+眩
+飙
+忌
+仝
+迦
+熬
+毫
+胯
+篑
+茄
+腺
+凄
+舛
+碴
+锵
+诧
+羯
+後
+漏
+汤
+宓
+仞
+蚁
+壶
+谰
+皑
+铄
+棰
+罔
+辅
+晶
+苦
+牟
+闽
+\
+烃
+饮
+聿
+丙
+蛳
+朱
+煤
+涔
+鳖
+犁
+罐
+荼
+砒
+淦
+妤
+黏
+戎
+孑
+婕
+瑾
+戢
+钵
+枣
+捋
+砥
+衩
+狙
+桠
+稣
+阎
+肃
+梏
+诫
+孪
+昶
+婊
+衫
+嗔
+侃
+塞
+蜃
+樵
+峒
+貌
+屿
+欺
+缫
+阐
+栖
+诟
+珞
+荭
+吝
+萍
+嗽
+恂
+啻
+蜴
+磬
+峋
+俸
+豫
+谎
+徊
+镍
+韬
+魇
+晴
+U
+囟
+猜
+蛮
+坐
+囿
+伴
+亭
+肝
+佗
+蝠
+妃
+胞
+滩
+榴
+氖
+垩
+苋
+砣
+扪
+馏
+姓
+轩
+厉
+夥
+侈
+禀
+垒
+岑
+赏
+钛
+辐
+痔
+披
+纸
+碳
+“
+坞
+蠓
+挤
+荥
+沅
+悔
+铧
+帼
+蒌
+蝇
+a
+p
+y
+n
+g
+哀
+浆
+瑶
+凿
+桶
+馈
+皮
+奴
+苜
+佤
+伶
+晗
+铱
+炬
+优
+弊
+氢
+恃
+甫
+攥
+端
+锌
+灰
+稹
+炝
+曙
+邋
+亥
+眶
+碾
+拉
+萝
+绔
+捷
+浍
+腋
+姑
+菖
+凌
+涞
+麽
+锢
+桨
+潢
+绎
+镰
+殆
+锑
+渝
+铬
+困
+绽
+觎
+匈
+糙
+暑
+裹
+鸟
+盔
+肽
+迷
+綦
+『
+亳
+佝
+俘
+钴
+觇
+骥
+仆
+疝
+跪
+婶
+郯
+瀹
+唉
+脖
+踞
+针
+晾
+忒
+扼
+瞩
+叛
+椒
+疟
+嗡
+邗
+肆
+跆
+玫
+忡
+捣
+咧
+唆
+艄
+蘑
+潦
+笛
+阚
+沸
+泻
+掊
+菽
+贫
+斥
+髂
+孢
+镂
+赂
+麝
+鸾
+屡
+衬
+苷
+恪
+叠
+希
+粤
+爻
+喝
+茫
+惬
+郸
+绻
+庸
+撅
+碟
+宄
+妹
+膛
+叮
+饵
+崛
+嗲
+椅
+冤
+搅
+咕
+敛
+尹
+垦
+闷
+蝉
+霎
+勰
+败
+蓑
+泸
+肤
+鹌
+幌
+焦
+浠
+鞍
+刁
+舰
+乙
+竿
+裔
+。
+茵
+函
+伊
+兄
+丨
+娜
+匍
+謇
+莪
+宥
+似
+蝽
+翳
+酪
+翠
+粑
+薇
+祢
+骏
+赠
+叫
+Q
+噤
+噻
+竖
+芗
+莠
+潭
+俊
+羿
+耜
+O
+郫
+趁
+嗪
+囚
+蹶
+芒
+洁
+笋
+鹑
+敲
+硝
+啶
+堡
+渲
+揩
+』
+携
+宿
+遒
+颍
+扭
+棱
+割
+萜
+蔸
+葵
+琴
+捂
+饰
+衙
+耿
+掠
+募
+岂
+窖
+涟
+蔺
+瘤
+柞
+瞪
+怜
+匹
+距
+楔
+炜
+哆
+秦
+缎
+幼
+茁
+绪
+痨
+恨
+楸
+娅
+瓦
+桩
+雪
+嬴
+伏
+榔
+妥
+铿
+拌
+眠
+雍
+缇
+‘
+卓
+搓
+哌
+觞
+噩
+屈
+哧
+髓
+咦
+巅
+娑
+侑
+淫
+膳
+祝
+勾
+姊
+莴
+胄
+疃
+薛
+蜷
+胛
+巷
+芙
+芋
+熙
+闰
+勿
+窃
+狱
+剩
+钏
+幢
+陟
+铛
+慧
+靴
+耍
+k
+浙
+浇
+飨
+惟
+绗
+祜
+澈
+啼
+咪
+磷
+摞
+诅
+郦
+抹
+跃
+壬
+吕
+肖
+琏
+颤
+尴
+剡
+抠
+凋
+赚
+泊
+津
+宕
+殷
+倔
+氲
+漫
+邺
+涎
+怠
+$
+垮
+荬
+遵
+俏
+叹
+噢
+饽
+蜘
+孙
+筵
+疼
+鞭
+羧
+牦
+箭
+潴
+c
+眸
+祭
+髯
+啖
+坳
+愁
+芩
+驮
+倡
+巽
+穰
+沃
+胚
+怒
+凤
+槛
+剂
+趵
+嫁
+v
+邢
+灯
+鄢
+桐
+睽
+檗
+锯
+槟
+婷
+嵋
+圻
+诗
+蕈
+颠
+遭
+痢
+芸
+怯
+馥
+竭
+锗
+徜
+恭
+遍
+籁
+剑
+嘱
+苡
+龄
+僧
+桑
+潸
+弘
+澶
+楹
+悲
+讫
+愤
+腥
+悸
+谍
+椹
+呢
+桓
+葭
+攫
+阀
+翰
+躲
+敖
+柑
+郎
+笨
+橇
+呃
+魁
+燎
+脓
+葩
+磋
+垛
+玺
+狮
+沓
+砜
+蕊
+锺
+罹
+蕉
+翱
+虐
+闾
+巫
+旦
+茱
+嬷
+枯
+鹏
+贡
+芹
+汛
+矫
+绁
+拣
+禺
+佃
+讣
+舫
+惯
+乳
+趋
+疲
+挽
+岚
+虾
+衾
+蠹
+蹂
+飓
+氦
+铖
+孩
+稞
+瑜
+壅
+掀
+勘
+妓
+畅
+髋
+W
+庐
+牲
+蓿
+榕
+练
+垣
+唱
+邸
+菲
+昆
+婺
+穿
+绡
+麒
+蚱
+掂
+愚
+泷
+涪
+漳
+妩
+娉
+榄
+讷
+觅
+旧
+藤
+煮
+呛
+柳
+腓
+叭
+庵
+烷
+阡
+罂
+蜕
+擂
+猖
+咿
+媲
+脉
+【
+沏
+貅
+黠
+熏
+哲
+烁
+坦
+酵
+兜
+×
+潇
+撒
+剽
+珩
+圹
+乾
+摸
+樟
+帽
+嗒
+襄
+魂
+轿
+憬
+锡
+〕
+喃
+皆
+咖
+隅
+脸
+残
+泮
+袂
+鹂
+珊
+囤
+捆
+咤
+误
+徨
+闹
+淙
+芊
+淋
+怆
+囗
+拨
+梳
+渤
+R
+G
+绨
+蚓
+婀
+幡
+狩
+麾
+谢
+唢
+裸
+旌
+伉
+纶
+裂
+驳
+砼
+咛
+澄
+樨
+蹈
+宙
+澍
+倍
+貔
+操
+勇
+蟠
+摈
+砧
+虬
+够
+缁
+悦
+藿
+撸
+艹
+摁
+淹
+豇
+虎
+榭
+ˉ
+吱
+d
+°
+喧
+荀
+踱
+侮
+奋
+偕
+饷
+犍
+惮
+坑
+璎
+徘
+宛
+妆
+袈
+倩
+窦
+昂
+荏
+乖
+K
+怅
+撰
+鳙
+牙
+袁
+酞
+X
+痿
+琼
+闸
+雁
+趾
+荚
+虻
+涝
+《
+杏
+韭
+偈
+烤
+绫
+鞘
+卉
+症
+遢
+蓥
+诋
+杭
+荨
+匆
+竣
+簪
+辙
+敕
+虞
+丹
+缭
+咩
+黟
+m
+淤
+瑕
+咂
+铉
+硼
+茨
+嶂
+痒
+畸
+敬
+涿
+粪
+窘
+熟
+叔
+嫔
+盾
+忱
+裘
+憾
+梵
+赡
+珙
+咯
+娘
+庙
+溯
+胺
+葱
+痪
+摊
+荷
+卞
+乒
+髦
+寐
+铭
+坩
+胗
+枷
+爆
+溟
+嚼
+羚
+砬
+轨
+惊
+挠
+罄
+竽
+菏
+氧
+浅
+楣
+盼
+枢
+炸
+阆
+杯
+谏
+噬
+淇
+渺
+俪
+秆
+墓
+泪
+跻
+砌
+痰
+垡
+渡
+耽
+釜
+讶
+鳎
+煞
+呗
+韶
+舶
+绷
+鹳
+缜
+旷
+铊
+皱
+龌
+檀
+霖
+奄
+槐
+艳
+蝶
+旋
+哝
+赶
+骞
+蚧
+腊
+盈
+丁
+`
+蜚
+矸
+蝙
+睨
+嚓
+僻
+鬼
+醴
+夜
+彝
+磊
+笔
+拔
+栀
+糕
+厦
+邰
+纫
+逭
+纤
+眦
+膊
+馍
+躇
+烯
+蘼
+冬
+诤
+暄
+骶
+哑
+瘠
+」
+臊
+丕
+愈
+咱
+螺
+擅
+跋
+搏
+硪
+谄
+笠
+淡
+嘿
+骅
+谧
+鼎
+皋
+姚
+歼
+蠢
+驼
+耳
+胬
+挝
+涯
+狗
+蒽
+孓
+犷
+凉
+芦
+箴
+铤
+孤
+嘛
+坤
+V
+茴
+朦
+挞
+尖
+橙
+诞
+搴
+碇
+洵
+浚
+帚
+蜍
+漯
+柘
+嚎
+讽
+芭
+荤
+咻
+祠
+秉
+跖
+埃
+吓
+糯
+眷
+馒
+惹
+娼
+鲑
+嫩
+讴
+轮
+瞥
+靶
+褚
+乏
+缤
+宋
+帧
+删
+驱
+碎
+扑
+俩
+俄
+偏
+涣
+竹
+噱
+皙
+佰
+渚
+唧
+斡
+#
+镉
+刀
+崎
+筐
+佣
+夭
+贰
+肴
+峙
+哔
+艿
+匐
+牺
+镛
+缘
+仡
+嫡
+劣
+枸
+堀
+梨
+簿
+鸭
+蒸
+亦
+稽
+浴
+{
+衢
+束
+槲
+j
+阁
+揍
+疥
+棋
+潋
+聪
+窜
+乓
+睛
+插
+冉
+阪
+苍
+搽
+「
+蟾
+螟
+幸
+仇
+樽
+撂
+慢
+跤
+幔
+俚
+淅
+覃
+觊
+溶
+妖
+帛
+侨
+曰
+妾
+泗
+·
+:
+瀘
+風
+Ë
+(
+)
+∶
+紅
+紗
+瑭
+雲
+頭
+鶏
+財
+許
+•
+¥
+樂
+焗
+麗
+—
+;
+滙
+東
+榮
+繪
+興
+…
+門
+業
+π
+楊
+國
+顧
+é
+盤
+寳
+Λ
+龍
+鳳
+島
+誌
+緣
+結
+銭
+萬
+勝
+祎
+璟
+優
+歡
+臨
+時
+購
+=
+★
+藍
+昇
+鐵
+觀
+勅
+農
+聲
+畫
+兿
+術
+發
+劉
+記
+專
+耑
+園
+書
+壴
+種
+Ο
+●
+褀
+號
+銀
+匯
+敟
+锘
+葉
+橪
+廣
+進
+蒄
+鑽
+阝
+祙
+貢
+鍋
+豊
+夬
+喆
+團
+閣
+開
+燁
+賓
+館
+酡
+沔
+順
++
+硚
+劵
+饸
+陽
+車
+湓
+復
+萊
+氣
+軒
+華
+堃
+迮
+纟
+戶
+馬
+學
+裡
+電
+嶽
+獨
+マ
+シ
+サ
+ジ
+燘
+袪
+環
+❤
+臺
+灣
+専
+賣
+孖
+聖
+攝
+線
+▪
+α
+傢
+俬
+夢
+達
+莊
+喬
+貝
+薩
+劍
+羅
+壓
+棛
+饦
+尃
+璈
+囍
+醫
+G
+I
+A
+#
+N
+鷄
+髙
+嬰
+啓
+約
+隹
+潔
+賴
+藝
+~
+寶
+籣
+麺
+
+嶺
+√
+義
+網
+峩
+長
+∧
+魚
+機
+構
+②
+鳯
+偉
+L
+B
+㙟
+畵
+鴿
+'
+詩
+溝
+嚞
+屌
+藔
+佧
+玥
+蘭
+織
+1
+3
+9
+0
+7
+點
+砭
+鴨
+鋪
+銘
+廳
+弍
+‧
+創
+湯
+坶
+℃
+卩
+骝
+&
+烜
+荘
+當
+潤
+扞
+係
+懷
+碶
+钅
+蚨
+讠
+☆
+叢
+爲
+埗
+涫
+塗
+→
+楽
+現
+鯨
+愛
+瑪
+鈺
+忄
+悶
+藥
+飾
+樓
+視
+孬
+ㆍ
+燚
+苪
+師
+①
+丼
+锽
+│
+韓
+標
+è
+兒
+閏
+匋
+張
+漢
+Ü
+髪
+會
+閑
+檔
+習
+裝
+の
+峯
+菘
+輝
+И
+雞
+釣
+億
+浐
+K
+O
+R
+8
+H
+E
+P
+T
+W
+D
+S
+C
+M
+F
+姌
+饹
+»
+晞
+廰
+ä
+嵯
+鷹
+負
+飲
+絲
+冚
+楗
+澤
+綫
+區
+❋
+←
+質
+靑
+揚
+③
+滬
+統
+産
+協
+﹑
+乸
+畐
+經
+運
+際
+洺
+岽
+為
+粵
+諾
+崋
+豐
+碁
+ɔ
+V
+2
+6
+齋
+誠
+訂
+´
+勑
+雙
+陳
+無
+í
+泩
+媄
+夌
+刂
+i
+c
+t
+o
+r
+a
+嘢
+耄
+燴
+暃
+壽
+媽
+靈
+抻
+體
+唻
+É
+冮
+甹
+鎮
+錦
+ʌ
+蜛
+蠄
+尓
+駕
+戀
+飬
+逹
+倫
+貴
+極
+Я
+Й
+寬
+磚
+嶪
+郎
+職
+|
+間
+n
+d
+剎
+伈
+課
+飛
+橋
+瘊
+№
+譜
+骓
+圗
+滘
+縣
+粿
+咅
+養
+濤
+彳
+®
+%
+Ⅱ
+啰
+㴪
+見
+矞
+薬
+糁
+邨
+鲮
+顔
+罱
+З
+選
+話
+贏
+氪
+俵
+競
+瑩
+繡
+枱
+β
+綉
+á
+獅
+爾
+™
+麵
+戋
+淩
+徳
+個
+劇
+場
+務
+簡
+寵
+h
+實
+膠
+轱
+圖
+築
+嘣
+樹
+㸃
+營
+耵
+孫
+饃
+鄺
+飯
+麯
+遠
+輸
+坫
+孃
+乚
+閃
+鏢
+㎡
+題
+廠
+關
+↑
+爺
+將
+軍
+連
+篦
+覌
+參
+箸
+-
+窠
+棽
+寕
+夀
+爰
+歐
+呙
+閥
+頡
+熱
+雎
+垟
+裟
+凬
+勁
+帑
+馕
+夆
+疌
+枼
+馮
+貨
+蒤
+樸
+彧
+旸
+靜
+龢
+暢
+㐱
+鳥
+珺
+鏡
+灡
+爭
+堷
+廚
+Ó
+騰
+診
+┅
+蘇
+褔
+凱
+頂
+豕
+亞
+帥
+嘬
+⊥
+仺
+桖
+複
+饣
+絡
+穂
+顏
+棟
+納
+▏
+濟
+親
+設
+計
+攵
+埌
+烺
+ò
+頤
+燦
+蓮
+撻
+節
+講
+濱
+濃
+娽
+洳
+朿
+燈
+鈴
+護
+膚
+铔
+過
+補
+Z
+U
+5
+4
+坋
+闿
+䖝
+餘
+缐
+铞
+貿
+铪
+桼
+趙
+鍊
+[
+㐂
+垚
+菓
+揸
+捲
+鐘
+滏
+𣇉
+爍
+輪
+燜
+鴻
+鮮
+動
+鹞
+鷗
+丄
+慶
+鉌
+翥
+飮
+腸
+⇋
+漁
+覺
+來
+熘
+昴
+翏
+鲱
+圧
+鄉
+萭
+頔
+爐
+嫚
+г
+貭
+類
+聯
+幛
+輕
+訓
+鑒
+夋
+锨
+芃
+珣
+䝉
+扙
+嵐
+銷
+處
+ㄱ
+語
+誘
+苝
+歸
+儀
+燒
+楿
+內
+粢
+葒
+奧
+麥
+礻
+滿
+蠔
+穵
+瞭
+態
+鱬
+榞
+硂
+鄭
+黃
+煙
+祐
+奓
+逺
+*
+瑄
+獲
+聞
+薦
+讀
+這
+樣
+決
+問
+啟
+們
+執
+説
+轉
+單
+隨
+唘
+帶
+倉
+庫
+還
+贈
+尙
+皺
+■
+餅
+產
+○
+∈
+報
+狀
+楓
+賠
+琯
+嗮
+禮
+`
+傳
+>
+≤
+嗞
+Φ
+≥
+換
+咭
+∣
+↓
+曬
+ε
+応
+寫
+″
+終
+様
+純
+費
+療
+聨
+凍
+壐
+郵
+ü
+黒
+∫
+製
+塊
+調
+軽
+確
+撃
+級
+馴
+Ⅲ
+涇
+繹
+數
+碼
+證
+狒
+処
+劑
+<
+晧
+賀
+衆
+]
+櫥
+兩
+陰
+絶
+對
+鯉
+憶
+◎
+p
+e
+Y
+蕒
+煖
+頓
+測
+試
+鼽
+僑
+碩
+妝
+帯
+≈
+鐡
+舖
+權
+喫
+倆
+ˋ
+該
+悅
+ā
+俫
+.
+f
+s
+b
+m
+k
+g
+u
+j
+貼
+淨
+濕
+針
+適
+備
+l
+/
+給
+謢
+強
+觸
+衛
+與
+⊙
+$
+緯
+變
+⑴
+⑵
+⑶
+㎏
+殺
+∩
+幚
+─
+價
+▲
+離
+ú
+ó
+飄
+烏
+関
+閟
+﹝
+﹞
+邏
+輯
+鍵
+驗
+訣
+導
+歷
+屆
+層
+▼
+儱
+錄
+熳
+ē
+艦
+吋
+錶
+辧
+飼
+顯
+④
+禦
+販
+気
+対
+枰
+閩
+紀
+幹
+瞓
+貊
+淚
+△
+眞
+墊
+Ω
+獻
+褲
+縫
+緑
+亜
+鉅
+餠
+{
+}
+◆
+蘆
+薈
+█
+◇
+溫
+彈
+晳
+粧
+犸
+穩
+訊
+崬
+凖
+熥
+П
+舊
+條
+紋
+圍
+Ⅳ
+筆
+尷
+難
+雜
+錯
+綁
+識
+頰
+鎖
+艶
+□
+殁
+殼
+⑧
+├
+▕
+鵬
+ǐ
+ō
+ǒ
+糝
+綱
+▎
+μ
+盜
+饅
+醬
+籤
+蓋
+釀
+鹽
+據
+à
+ɡ
+辦
+◥
+彐
+┌
+婦
+獸
+鲩
+伱
+ī
+蒟
+蒻
+齊
+袆
+腦
+寧
+凈
+妳
+煥
+詢
+偽
+謹
+啫
+鯽
+騷
+鱸
+損
+傷
+鎻
+髮
+買
+冏
+儥
+両
+﹢
+∞
+載
+喰
+z
+羙
+悵
+燙
+曉
+員
+組
+徹
+艷
+痠
+鋼
+鼙
+縮
+細
+嚒
+爯
+≠
+維
+"
+鱻
+壇
+厍
+帰
+浥
+犇
+薡
+軎
+²
+應
+醜
+刪
+緻
+鶴
+賜
+噁
+軌
+尨
+镔
+鷺
+槗
+彌
+葚
+濛
+請
+溇
+緹
+賢
+訪
+獴
+瑅
+資
+縤
+陣
+蕟
+栢
+韻
+祼
+恁
+伢
+謝
+劃
+涑
+總
+衖
+踺
+砋
+凉
+籃
+駿
+苼
+瘋
+昽
+紡
+驊
+腎
+﹗
+響
+杋
+剛
+嚴
+禪
+歓
+槍
+傘
+檸
+檫
+炣
+勢
+鏜
+鎢
+銑
+尐
+減
+奪
+惡
+θ
+僮
+婭
+臘
+ū
+ì
+殻
+鉄
+∑
+蛲
+焼
+緖
+續
+紹
+懮
\ No newline at end of file
diff --git a/iopaint/model/anytext/utils.py b/iopaint/model/anytext/utils.py
new file mode 100644
index 0000000..c9f55b8
--- /dev/null
+++ b/iopaint/model/anytext/utils.py
@@ -0,0 +1,151 @@
+import os
+import datetime
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw
+
+
+def save_images(img_list, folder):
+ if not os.path.exists(folder):
+ os.makedirs(folder)
+ now = datetime.datetime.now()
+ date_str = now.strftime("%Y-%m-%d")
+ folder_path = os.path.join(folder, date_str)
+ if not os.path.exists(folder_path):
+ os.makedirs(folder_path)
+ time_str = now.strftime("%H_%M_%S")
+ for idx, img in enumerate(img_list):
+ image_number = idx + 1
+ filename = f"{time_str}_{image_number}.jpg"
+ save_path = os.path.join(folder_path, filename)
+ cv2.imwrite(save_path, img[..., ::-1])
+
+
+def check_channels(image):
+ channels = image.shape[2] if len(image.shape) == 3 else 1
+ if channels == 1:
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+ elif channels > 3:
+ image = image[:, :, :3]
+ return image
+
+
+def resize_image(img, max_length=768):
+ height, width = img.shape[:2]
+ max_dimension = max(height, width)
+
+ if max_dimension > max_length:
+ scale_factor = max_length / max_dimension
+ new_width = int(round(width * scale_factor))
+ new_height = int(round(height * scale_factor))
+ new_size = (new_width, new_height)
+ img = cv2.resize(img, new_size)
+ height, width = img.shape[:2]
+ img = cv2.resize(img, (width - (width % 64), height - (height % 64)))
+ return img
+
+
+def insert_spaces(string, nSpace):
+ if nSpace == 0:
+ return string
+ new_string = ""
+ for char in string:
+ new_string += char + " " * nSpace
+ return new_string[:-nSpace]
+
+
+def draw_glyph(font, text):
+ g_size = 50
+ W, H = (512, 80)
+ new_font = font.font_variant(size=g_size)
+ img = Image.new(mode="1", size=(W, H), color=0)
+ draw = ImageDraw.Draw(img)
+ left, top, right, bottom = new_font.getbbox(text)
+ text_width = max(right - left, 5)
+ text_height = max(bottom - top, 5)
+ ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
+ new_font = font.font_variant(size=int(g_size * ratio))
+
+ text_width, text_height = new_font.getsize(text)
+ offset_x, offset_y = new_font.getoffset(text)
+ x = (img.width - text_width) // 2
+ y = (img.height - text_height) // 2 - offset_y // 2
+ draw.text((x, y), text, font=new_font, fill="white")
+ img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
+ return img
+
+
+def draw_glyph2(
+ font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True
+):
+ enlarge_polygon = polygon * scale
+ rect = cv2.minAreaRect(enlarge_polygon)
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ w, h = rect[1]
+ angle = rect[2]
+ if angle < -45:
+ angle += 90
+ angle = -angle
+ if w < h:
+ angle += 90
+
+ vert = False
+ if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
+ _w = max(box[:, 0]) - min(box[:, 0])
+ _h = max(box[:, 1]) - min(box[:, 1])
+ if _h >= _w:
+ vert = True
+ angle = 0
+
+ img = np.zeros((height * scale, width * scale, 3), np.uint8)
+ img = Image.fromarray(img)
+
+ # infer font size
+ image4ratio = Image.new("RGB", img.size, "white")
+ draw = ImageDraw.Draw(image4ratio)
+ _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
+ text_w = min(w, h) * (_tw / _th)
+ if text_w <= max(w, h):
+ # add space
+ if len(text) > 1 and not vert and add_space:
+ for i in range(1, 100):
+ text_space = insert_spaces(text, i)
+ _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
+ if min(w, h) * (_tw2 / _th2) > max(w, h):
+ break
+ text = insert_spaces(text, i - 1)
+ font_size = min(w, h) * 0.80
+ else:
+ shrink = 0.75 if vert else 0.85
+ font_size = min(w, h) / (text_w / max(w, h)) * shrink
+ new_font = font.font_variant(size=int(font_size))
+
+ left, top, right, bottom = new_font.getbbox(text)
+ text_width = right - left
+ text_height = bottom - top
+
+ layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
+ draw = ImageDraw.Draw(layer)
+ if not vert:
+ draw.text(
+ (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
+ text,
+ font=new_font,
+ fill=(255, 255, 255, 255),
+ )
+ else:
+ x_s = min(box[:, 0]) + _w // 2 - text_height // 2
+ y_s = min(box[:, 1])
+ for c in text:
+ draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
+ _, _t, _, _b = new_font.getbbox(c)
+ y_s += _b
+
+ rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
+
+ x_offset = int((img.width - rotated_layer.width) / 2)
+ y_offset = int((img.height - rotated_layer.height) / 2)
+ img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
+ img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
+ return img