IOPaint/inpaint/model/anytext/cldm/embedding_manager.py
root 70af4845af new file: inpaint/__init__.py
new file:   inpaint/__main__.py
	new file:   inpaint/api.py
	new file:   inpaint/batch_processing.py
	new file:   inpaint/benchmark.py
	new file:   inpaint/cli.py
	new file:   inpaint/const.py
	new file:   inpaint/download.py
	new file:   inpaint/file_manager/__init__.py
	new file:   inpaint/file_manager/file_manager.py
	new file:   inpaint/file_manager/storage_backends.py
	new file:   inpaint/file_manager/utils.py
	new file:   inpaint/helper.py
	new file:   inpaint/installer.py
	new file:   inpaint/model/__init__.py
	new file:   inpaint/model/anytext/__init__.py
	new file:   inpaint/model/anytext/anytext_model.py
	new file:   inpaint/model/anytext/anytext_pipeline.py
	new file:   inpaint/model/anytext/anytext_sd15.yaml
	new file:   inpaint/model/anytext/cldm/__init__.py
	new file:   inpaint/model/anytext/cldm/cldm.py
	new file:   inpaint/model/anytext/cldm/ddim_hacked.py
	new file:   inpaint/model/anytext/cldm/embedding_manager.py
	new file:   inpaint/model/anytext/cldm/hack.py
	new file:   inpaint/model/anytext/cldm/model.py
	new file:   inpaint/model/anytext/cldm/recognizer.py
	new file:   inpaint/model/anytext/ldm/__init__.py
	new file:   inpaint/model/anytext/ldm/models/__init__.py
	new file:   inpaint/model/anytext/ldm/models/autoencoder.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/__init__.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/ddim.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/ddpm.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/plms.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/sampling_util.py
	new file:   inpaint/model/anytext/ldm/modules/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/attention.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/model.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/util.py
	new file:   inpaint/model/anytext/ldm/modules/distributions/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/distributions/distributions.py
	new file:   inpaint/model/anytext/ldm/modules/ema.py
	new file:   inpaint/model/anytext/ldm/modules/encoders/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/encoders/modules.py
	new file:   inpaint/model/anytext/ldm/util.py
	new file:   inpaint/model/anytext/main.py
	new file:   inpaint/model/anytext/ocr_recog/RNN.py
	new file:   inpaint/model/anytext/ocr_recog/RecCTCHead.py
	new file:   inpaint/model/anytext/ocr_recog/RecModel.py
	new file:   inpaint/model/anytext/ocr_recog/RecMv1_enhance.py
	new file:   inpaint/model/anytext/ocr_recog/RecSVTR.py
	new file:   inpaint/model/anytext/ocr_recog/__init__.py
	new file:   inpaint/model/anytext/ocr_recog/common.py
	new file:   inpaint/model/anytext/ocr_recog/en_dict.txt
	new file:   inpaint/model/anytext/ocr_recog/ppocr_keys_v1.txt
	new file:   inpaint/model/anytext/utils.py
	new file:   inpaint/model/base.py
	new file:   inpaint/model/brushnet/__init__.py
	new file:   inpaint/model/brushnet/brushnet.py
	new file:   inpaint/model/brushnet/brushnet_unet_forward.py
	new file:   inpaint/model/brushnet/brushnet_wrapper.py
	new file:   inpaint/model/brushnet/pipeline_brushnet.py
	new file:   inpaint/model/brushnet/unet_2d_blocks.py
	new file:   inpaint/model/controlnet.py
	new file:   inpaint/model/ddim_sampler.py
	new file:   inpaint/model/fcf.py
	new file:   inpaint/model/helper/__init__.py
	new file:   inpaint/model/helper/controlnet_preprocess.py
	new file:   inpaint/model/helper/cpu_text_encoder.py
	new file:   inpaint/model/helper/g_diffuser_bot.py
	new file:   inpaint/model/instruct_pix2pix.py
	new file:   inpaint/model/kandinsky.py
	new file:   inpaint/model/lama.py
	new file:   inpaint/model/ldm.py
	new file:   inpaint/model/manga.py
	new file:   inpaint/model/mat.py
	new file:   inpaint/model/mi_gan.py
	new file:   inpaint/model/opencv2.py
	new file:   inpaint/model/original_sd_configs/__init__.py
	new file:   inpaint/model/original_sd_configs/sd_xl_base.yaml
	new file:   inpaint/model/original_sd_configs/sd_xl_refiner.yaml
	new file:   inpaint/model/original_sd_configs/v1-inference.yaml
	new file:   inpaint/model/original_sd_configs/v2-inference-v.yaml
	new file:   inpaint/model/paint_by_example.py
	new file:   inpaint/model/plms_sampler.py
	new file:   inpaint/model/power_paint/__init__.py
	new file:   inpaint/model/power_paint/pipeline_powerpaint.py
	new file:   inpaint/model/power_paint/power_paint.py
	new file:   inpaint/model/power_paint/power_paint_v2.py
	new file:   inpaint/model/power_paint/powerpaint_tokenizer.py
2024-08-20 21:17:33 +02:00

166 lines
5.9 KiB
Python

'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear
def get_clip_token_for_string(tokenizer, string):
    """Return the single CLIP token id that ``string`` tokenizes to.

    The string is tokenized padded/truncated to 77 positions. After removing
    every position equal to 49407 (presumably CLIP's end-of-text id, which
    also serves as padding here — TODO confirm for the tokenizer in use),
    exactly two ids must remain: the leading start token and the string's own
    token.

    Args:
        tokenizer: callable with a HuggingFace-style tokenizer signature that
            returns a mapping containing ``"input_ids"`` as a tensor.
        string: placeholder text expected to map to one token.

    Returns:
        The 0-d tensor ``input_ids[0, 1]``.

    Raises:
        AssertionError: if the string does not map to exactly one token.
    """
    encoded = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    n_kept = torch.count_nonzero(input_ids - 49407)
    assert n_kept == 2, f"String '{string}' maps to more than a single token. Please use another string"
    return input_ids[0, 1]
def get_bert_token_for_string(tokenizer, string):
    """Return the single token id that ``string`` maps to under a BERT-style tokenizer.

    ``tokenizer(string)`` must return a 2-D tensor of ids in which exactly three
    entries are non-zero (presumably start token, the string's token and end
    token, with zero padding — TODO confirm against ``tknz_fn``).

    Raises:
        AssertionError: if the string does not map to exactly one token.
    """
    ids = tokenizer(string)
    assert torch.count_nonzero(ids) == 3, f"String '{string}' maps to more than a single token. Please use another string"
    return ids[0, 1]
def get_clip_vision_emb(encoder, processor, img):
    """Embed ``img`` with a CLIP-style vision encoder.

    The input's channel dimension is repeated 3 times and the result is scaled
    by 255 before being passed to ``processor`` (presumably ``img`` is a
    single-channel image in [0, 1] — TODO confirm with the caller). The
    processed ``pixel_values`` are moved back to ``img``'s device.

    Returns:
        ``encoder(**inputs).image_embeds``.
    """
    rgb = img.repeat(1, 3, 1, 1) * 255
    batch = processor(images=rgb, return_tensors="pt")
    batch['pixel_values'] = batch['pixel_values'].to(img.device)
    return encoder(**batch).image_embeds
def get_recog_emb(encoder, img_list):
    """Run the text recognizer on a list of glyph-line images.

    Each image's channel dim is repeated 3 times, it is scaled by 255, and its
    first batch element is taken. The recognizer's predictor is switched to
    eval mode before ``pred_imglist`` is called.

    Returns:
        The second value returned by ``encoder.pred_imglist`` (named
        ``preds_neck``; presumably the recognizer's neck features).
    """
    rgb_images = []
    for line_img in img_list:
        rgb_images.append((line_img.repeat(1, 3, 1, 1) * 255)[0])
    encoder.predictor.eval()
    _unused, preds_neck = encoder.pred_imglist(rgb_images, show_debug=False)
    return preds_neck
def pad_H(x):
    """Zero-pad the height of a 4-D tensor ``x`` so that it equals its width.

    The extra rows are split between top and bottom, with the bottom getting
    the odd row. If height exceeds width, the negative padding makes
    ``F.pad`` crop instead.
    """
    _n, _c, height, width = x.shape
    extra = width - height
    top = extra // 2
    return F.pad(x, (0, 0, top, extra - top))
class EncodeNet(nn.Module):
    """Small convolutional encoder that maps a feature map to one vector per sample.

    Layout: a 3x3 conv to 16 channels, then four stride-2 3x3 convs that double
    the channel count each time (16 -> 256), then a 3x3 conv to
    ``out_channels``. SiLU follows every conv, and the result is globally
    average-pooled and flattened. ``conv_nd(2, ...)`` is presumably a 2-D
    convolution factory (TODO confirm in diffusionmodules.util).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        base = 16
        depth = 4  # number of stride-2 downsampling convs
        self.conv1 = conv_nd(2, in_channels, base, 3, padding=1)
        widths = [base * 2 ** k for k in range(depth + 1)]
        self.conv_list = nn.ModuleList(
            [
                conv_nd(2, c_in, c_out, 3, padding=1, stride=2)
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ]
        )
        self.conv2 = conv_nd(2, widths[-1], out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        """Encode ``x`` into a ``(batch, out_channels)`` tensor."""
        h = self.act(self.conv1(x))
        for down in self.conv_list:
            h = self.act(down(h))
        h = self.avgpool(self.act(self.conv2(h)))
        return h.view(h.size(0), -1)
class EmbeddingManager(nn.Module):
    """Swap placeholder-token embeddings in a caption for per-line glyph embeddings.

    Workflow: call ``encode_text(text_info)`` to encode every text line in the
    batch and cache the results. Then ``forward(tokenized_text, embedded_text)``
    writes the k-th cached line embedding of sample i into the k-th occurrence
    of the placeholder token in sample i's embedded caption.
    """

    def __init__(
        self,
        embedder,
        valid=True,
        glyph_channels=20,
        position_channels=1,
        placeholder_string='*',
        add_pos=False,
        emb_type='ocr',
        **kwargs
    ):
        """
        Args:
            embedder: text encoder. If it has a ``tokenizer`` attribute it is
                treated as a CLIP encoder (token_dim 768); an optional ``vit``
                and ``processor`` pair enables ``emb_type='vit'``. Otherwise
                it must expose ``tknz_fn`` (BERT path, token_dim 1280).
            valid: accepted for config compatibility; unused here.
            glyph_channels: input channels of the glyph encoder (``'conv'``).
            position_channels: input channels of the position encoder.
            placeholder_string: string that must map to exactly one token.
            add_pos: if True, add a position-encoder output to each glyph embedding.
            emb_type: ``'ocr'``, ``'vit'`` or ``'conv'``.
            **kwargs: ignored.

        Note: the ``'ocr'`` path reads ``self.recog`` in ``encode_text``, and
        that attribute is not set in this class. It must be attached by the
        owner before the first ``encode_text`` call (presumably the OCR
        recognizer model — verify against the caller).
        """
        super().__init__()
        # Always define this so encode_text works for both CLIP and BERT embedders;
        # the OCR function is bound lazily once ``self.recog`` exists.
        self.get_recog_emb = None
        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
            token_dim = 768
            if hasattr(embedder, 'vit'):
                assert emb_type == 'vit'
                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type
        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == 'ocr':
            # Assumes the recognizer features flatten to 40*64 values per line — TODO confirm.
            self.proj = linear(40 * 64, token_dim)
        elif emb_type == 'conv':
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
        self.placeholder_token = get_token_for_string(placeholder_string)

    def encode_text(self, text_info):
        """Encode every text line in the batch and cache the result for ``forward``.

        Args:
            text_info: mapping with ``'n_lines'`` (per-sample line counts) and
                ``'gly_line'``, indexed ``[line][sample]``. When ``add_pos`` is
                set it also needs ``'positions'``, indexed the same way. Exact
                tensor shapes are defined by the caller — TODO confirm.

        Sets ``self.text_embs_all``: one list per sample, holding one one-row
        slice of the encoded batch per text line, in line order.

        Raises:
            ValueError: if there is at least one line and ``emb_type`` is not
                ``'ocr'``, ``'vit'`` or ``'conv'``.
        """
        if self.get_recog_emb is None and self.emb_type == 'ocr':
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        gline_list = []
        pos_list = []
        for i, n_lines in enumerate(text_info['n_lines']):  # sample index in a batch
            for j in range(n_lines):  # line
                gline_list.append(text_info['gly_line'][j][i:i+1])
                if self.add_pos:
                    pos_list.append(text_info['positions'][j][i:i+1])

        if gline_list:
            if self.emb_type == 'ocr':
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == 'vit':
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == 'conv':
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            else:
                raise ValueError(f"Unsupported emb_type: {self.emb_type!r}")
            if self.add_pos:
                # Encode the line positions collected above, not the glyphs a second time.
                enc_pos = self.position_encoder(torch.cat(pos_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        # Split the flat per-line batch back into per-sample lists.
        self.text_embs_all = []
        n_idx = 0
        for n_lines in text_info['n_lines']:
            text_embs = []
            for _ in range(n_lines):
                text_embs.append(enc_glyph[n_idx:n_idx+1])
                n_idx += 1
            self.text_embs_all.append(text_embs)

    def forward(
        self,
        tokenized_text,
        embedded_text,
    ):
        """Overwrite placeholder positions in ``embedded_text`` with cached line embeddings.

        Args:
            tokenized_text: token ids, one row per sample.
            embedded_text: token embeddings with the same leading
                ``(batch, seq)`` layout. It is modified in place and returned.

        The k-th placeholder of sample i gets line k of sample i. If the counts
        differ, only ``min(#placeholders, #lines)`` slots are filled, and any
        extra placeholders keep their original embedding. Requires
        ``encode_text`` to have been called first.
        """
        b, device = tokenized_text.shape[0], tokenized_text.device
        placeholder = self.placeholder_token.to(device)
        for i in range(b):
            idx = tokenized_text[i] == placeholder
            n_slots = int(idx.sum())
            if n_slots == 0:
                continue
            if i >= len(self.text_embs_all):
                print('truncation for log images...')
                break
            line_embs = self.text_embs_all[i]
            if len(line_embs) == 0:
                # Placeholders but no encoded lines: nothing to substitute.
                continue
            text_emb = torch.cat(line_embs, dim=0)
            if n_slots != len(text_emb):
                print('truncation for long caption...')
            n_fill = min(n_slots, len(text_emb))
            slot_pos = idx.nonzero(as_tuple=True)[0][:n_fill]
            embedded_text[i][slot_pos] = text_emb[:n_fill]
        return embedded_text

    def embedding_parameters(self):
        """Return all parameters of this module."""
        return self.parameters()