add AnyText
This commit is contained in:
parent
f5bd697687
commit
1905743886
72
README.md
72
README.md
@ -1 +1,71 @@
|
||||
# IOPaint
|
||||
<h1 align="center">IOPaint</h1>
|
||||
<p align="center">A free and open-source inpainting & outpainting tool powered by SOTA AI model.</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/Sanster/IOPaint">
|
||||
<img alt="total download" src="https://pepy.tech/badge/iopaint" />
|
||||
</a>
|
||||
<a href="https://pypi.org/project/iopaint">
|
||||
<img alt="version" src="https://img.shields.io/pypi/v/iopaint" />
|
||||
</a>
|
||||
<a href="">
|
||||
<img alt="python version" src="https://img.shields.io/pypi/pyversions/iopaint" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
## Quick Start
|
||||
|
||||
IOPaint provides an easy-to-use webui for utilizing the latest AI models. The installation process for IOPaint is also simple, requiring just two commands:
|
||||
|
||||
```bash
|
||||
# In order to use GPU, install cuda version of pytorch first.
|
||||
# pip3 install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
pip3 install iopaint
|
||||
iopaint start --model=lama --device=cpu --port=8080
|
||||
```
|
||||
|
||||
That's it, you can start using IOPaint by visiting http://localhost:8080 in your web browser.
|
||||
|
||||
You can also use IOPaint in the command line to batch process images:
|
||||
|
||||
```bash
|
||||
iopaint run --model=lama --device=cpu \
|
||||
--input=/path/to/image_folder \
|
||||
--mask=/path/to/mask_folder \
|
||||
--output=output_dir
|
||||
```
|
||||
|
||||
`--input` is the folder containing input images, `--mask` is the folder containing corresponding mask images.
|
||||
When `--mask` is a path to a mask file, all images will be processed using this mask.
|
||||
|
||||
You can see more information about the models and plugins supported by IOPaint below.
|
||||
|
||||
## Features
|
||||
|
||||
- Completely free and open-source, fully self-hosted, support CPU & GPU & M1/2
|
||||
- Supports various AI models:
|
||||
- Inpainting models: These models are usually used to remove people or objects from images.
|
||||
- Stable Diffusion models: These models have stronger generation abilities, allowing them to generate new objects on images, or to expand existing images.
|
||||
You can use any Stable Diffusion Inpainting(or normal) models from [Huggingface](https://huggingface.co/models?other=stable-diffusion) in IOPaint.
|
||||
Some commonly used models are listed below:
|
||||
- [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)
|
||||
- [diffusers/stable-diffusion-xl-1.0-inpainting-0.1](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1)
|
||||
- [andregn/Realistic_Vision_V3.0-inpainting](https://huggingface.co/andregn/Realistic_Vision_V3.0-inpainting)
|
||||
- [Lykon/dreamshaper-8-inpainting](https://huggingface.co/Lykon/dreamshaper-8-inpainting)
|
||||
- [Sanster/anything-4.0-inpainting](https://huggingface.co/Sanster/anything-4.0-inpainting)
|
||||
- [Sanster/PowerPaint-V1-stable-diffusion-inpainting](https://huggingface.co/Sanster/PowerPaint-V1-stable-diffusion-inpainting)
|
||||
- Other Diffusion models:
|
||||
- [Sanster/AnyText](https://huggingface.co/Sanster/AnyText): Generate text on images
|
||||
- [timbrooks/instruct-pix2pix](https://huggingface.co/timbrooks/instruct-pix2pix)
|
||||
- [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example): Generate images from text
|
||||
- [kandinsky-community/kandinsky-2-2-decoder-inpaint](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint)
|
||||
- [Plugins](https://iopaint.com/plugins) for post-processing:
|
||||
- [Segment Anything](https://iopaint.com/plugins/interactive_seg): Accurate and fast interactive object segmentation
|
||||
- [RemoveBG](https://iopaint.com/plugins/rembg): Remove image background or generate masks for foreground objects
|
||||
- [Anime Segmentation](https://iopaint.com/plugins/anime_seg): Similar to RemoveBG, the model is specifically trained for anime images.
|
||||
- [RealESRGAN](https://iopaint.com/plugins/RealESRGAN): Super Resolution
|
||||
- [GFPGAN](https://iopaint.com/plugins/GFPGAN): Face Restoration
|
||||
- [RestoreFormer](https://iopaint.com/plugins/RestoreFormer): Face Restoration
|
||||
- [FileManager](https://iopaint.com/features/file_manager): Browse your pictures conveniently and save them directly to the output directory.
|
||||
- [Native macOS app](https://opticlean.io/) for erase task
|
||||
- More features at [IOPaint Docs](https://iopaint.com/)
|
||||
|
@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
|
||||
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||
ANYTEXT_NAME = "Sanster/AnyText"
|
||||
|
||||
|
||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||
|
@ -12,6 +12,7 @@ from iopaint.const import (
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
|
||||
ANYTEXT_NAME,
|
||||
)
|
||||
from iopaint.model_info import ModelInfo, ModelType
|
||||
|
||||
@ -24,6 +25,10 @@ def cli_download_model(model: str):
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info(f"Done.")
|
||||
elif model == ANYTEXT_NAME:
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info(f"Done.")
|
||||
else:
|
||||
logger.info(f"Downloading model from Huggingface: {model}")
|
||||
from diffusers import DiffusionPipeline
|
||||
@ -210,6 +215,7 @@ def scan_models() -> List[ModelInfo]:
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
"AnyText",
|
||||
]:
|
||||
model_type = ModelType.DIFFUSERS_OTHER
|
||||
else:
|
||||
|
@ -1,3 +1,4 @@
|
||||
from .anytext.anytext_model import AnyText
|
||||
from .controlnet import ControlNet
|
||||
from .fcf import FcF
|
||||
from .instruct_pix2pix import InstructPix2Pix
|
||||
@ -32,4 +33,5 @@ models = {
|
||||
Kandinsky22.name: Kandinsky22,
|
||||
SDXL.name: SDXL,
|
||||
PowerPaint.name: PowerPaint,
|
||||
AnyText.name: AnyText,
|
||||
}
|
||||
|
@ -0,0 +1,73 @@
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from iopaint.const import ANYTEXT_NAME
|
||||
from iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
|
||||
from iopaint.model.base import DiffusionInpaintModel
|
||||
from iopaint.model.utils import get_torch_dtype, is_local_files_only
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
|
||||
class AnyText(DiffusionInpaintModel):
|
||||
name = ANYTEXT_NAME
|
||||
pad_mod = 64
|
||||
is_erase_model = False
|
||||
|
||||
@staticmethod
|
||||
def download(local_files_only=False):
|
||||
hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="model_index.json",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
ckpt_path = hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="pytorch_model.fp16.safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
font_path = hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="SourceHanSansSC-Medium.otf",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return ckpt_path, font_path
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
local_files_only = is_local_files_only(**kwargs)
|
||||
ckpt_path, font_path = self.download(local_files_only)
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
self.model = AnyTextPipeline(
|
||||
ckpt_path=ckpt_path,
|
||||
font_path=font_path,
|
||||
device=device,
|
||||
use_fp16=torch_dtype == torch.float16,
|
||||
)
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to inpainting
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
height, width = image.shape[:2]
|
||||
mask = mask.astype("float32") / 255.0
|
||||
masked_image = image * (1 - mask)
|
||||
|
||||
# list of rgb ndarray
|
||||
results, rtn_code, rtn_warning = self.model(
|
||||
image=image,
|
||||
masked_image=masked_image,
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
num_inference_steps=config.sd_steps,
|
||||
strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
height=height,
|
||||
width=width,
|
||||
seed=config.sd_seed,
|
||||
sort_priority="y",
|
||||
callback=self.callback
|
||||
)
|
||||
inpainted_rgb_image = results[0][..., ::-1]
|
||||
return inpainted_rgb_image
|
@ -5,20 +5,21 @@ Code: https://github.com/tyxsspa/AnyText
|
||||
Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from iopaint.model.utils import set_seed
|
||||
from safetensors.torch import load_file
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
import torch
|
||||
import random
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import einops
|
||||
import time
|
||||
from PIL import ImageFont
|
||||
from iopaint.model.anytext.cldm.model import create_model, load_state_dict
|
||||
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
|
||||
from iopaint.model.anytext.utils import (
|
||||
resize_image,
|
||||
check_channels,
|
||||
draw_glyph,
|
||||
draw_glyph2,
|
||||
@ -29,55 +30,93 @@ BBOX_MAX_NUM = 8
|
||||
PLACE_HOLDER = "*"
|
||||
max_chars = 20
|
||||
|
||||
ANYTEXT_CFG = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
|
||||
)
|
||||
|
||||
|
||||
def check_limits(tensor):
|
||||
float16_min = torch.finfo(torch.float16).min
|
||||
float16_max = torch.finfo(torch.float16).max
|
||||
|
||||
# 检查张量中是否有值小于float16的最小值或大于float16的最大值
|
||||
is_below_min = (tensor < float16_min).any()
|
||||
is_above_max = (tensor > float16_max).any()
|
||||
|
||||
return is_below_min or is_above_max
|
||||
|
||||
|
||||
class AnyTextPipeline:
|
||||
def __init__(self, cfg_path, model_dir, font_path, device, use_fp16=True):
|
||||
self.cfg_path = cfg_path
|
||||
self.model_dir = model_dir
|
||||
def __init__(self, ckpt_path, font_path, device, use_fp16=True):
|
||||
self.cfg_path = ANYTEXT_CFG
|
||||
self.font_path = font_path
|
||||
self.use_fp16 = use_fp16
|
||||
self.device = device
|
||||
self.init_model()
|
||||
|
||||
self.font = ImageFont.truetype(font_path, size=60)
|
||||
self.model = create_model(
|
||||
self.cfg_path,
|
||||
device=self.device,
|
||||
use_fp16=self.use_fp16,
|
||||
)
|
||||
if self.use_fp16:
|
||||
self.model = self.model.half()
|
||||
if Path(ckpt_path).suffix == ".safetensors":
|
||||
state_dict = load_file(ckpt_path, device="cpu")
|
||||
else:
|
||||
state_dict = load_state_dict(ckpt_path, location="cpu")
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model = self.model.eval().to(self.device)
|
||||
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image: np.ndarray,
|
||||
masked_image: np.ndarray,
|
||||
num_inference_steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
height: int,
|
||||
width: int,
|
||||
seed: int,
|
||||
sort_priority: str = "y",
|
||||
callback=None,
|
||||
):
|
||||
"""
|
||||
return:
|
||||
|
||||
Args:
|
||||
prompt:
|
||||
negative_prompt:
|
||||
image:
|
||||
masked_image:
|
||||
num_inference_steps:
|
||||
strength:
|
||||
guidance_scale:
|
||||
height:
|
||||
width:
|
||||
seed:
|
||||
sort_priority: x: left-right, y: top-down
|
||||
|
||||
Returns:
|
||||
result: list of images in numpy.ndarray format
|
||||
rst_code: 0: normal -1: error 1:warning
|
||||
rst_info: string of error or warning
|
||||
debug_info: string for debug, only valid if show_debug=True
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, input_tensor, **forward_params):
|
||||
tic = time.time()
|
||||
set_seed(seed)
|
||||
str_warning = ""
|
||||
# get inputs
|
||||
seed = input_tensor.get("seed", -1)
|
||||
if seed == -1:
|
||||
seed = random.randint(0, 99999999)
|
||||
# seed_everything(seed)
|
||||
prompt = input_tensor.get("prompt")
|
||||
draw_pos = input_tensor.get("draw_pos")
|
||||
ori_image = input_tensor.get("ori_image")
|
||||
|
||||
mode = forward_params.get("mode")
|
||||
sort_priority = forward_params.get("sort_priority", "↕")
|
||||
show_debug = forward_params.get("show_debug", False)
|
||||
revise_pos = forward_params.get("revise_pos", False)
|
||||
img_count = forward_params.get("image_count", 4)
|
||||
ddim_steps = forward_params.get("ddim_steps", 20)
|
||||
w = forward_params.get("image_width", 512)
|
||||
h = forward_params.get("image_height", 512)
|
||||
strength = forward_params.get("strength", 1.0)
|
||||
cfg_scale = forward_params.get("cfg_scale", 9.0)
|
||||
eta = forward_params.get("eta", 0.0)
|
||||
a_prompt = forward_params.get(
|
||||
"a_prompt",
|
||||
"best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks",
|
||||
)
|
||||
n_prompt = forward_params.get(
|
||||
"n_prompt",
|
||||
"low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
)
|
||||
mode = "text-editing"
|
||||
revise_pos = False
|
||||
img_count = 1
|
||||
ddim_steps = num_inference_steps
|
||||
w = width
|
||||
h = height
|
||||
strength = strength
|
||||
cfg_scale = guidance_scale
|
||||
eta = 0.0
|
||||
|
||||
prompt, texts = self.modify_prompt(prompt)
|
||||
if prompt is None and texts is None:
|
||||
@ -91,43 +130,44 @@ class AnyTextPipeline:
|
||||
if mode in ["text-generation", "gen"]:
|
||||
edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
|
||||
elif mode in ["text-editing", "edit"]:
|
||||
if draw_pos is None or ori_image is None:
|
||||
if masked_image is None or image is None:
|
||||
return (
|
||||
None,
|
||||
-1,
|
||||
"Reference image and position image are needed for text editing!",
|
||||
"",
|
||||
)
|
||||
if isinstance(ori_image, str):
|
||||
ori_image = cv2.imread(ori_image)[..., ::-1]
|
||||
assert (
|
||||
ori_image is not None
|
||||
), f"Can't read ori_image image from{ori_image}!"
|
||||
elif isinstance(ori_image, torch.Tensor):
|
||||
ori_image = ori_image.cpu().numpy()
|
||||
if isinstance(image, str):
|
||||
image = cv2.imread(image)[..., ::-1]
|
||||
assert image is not None, f"Can't read ori_image image from{image}!"
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
else:
|
||||
assert isinstance(
|
||||
ori_image, np.ndarray
|
||||
), f"Unknown format of ori_image: {type(ori_image)}"
|
||||
edit_image = ori_image.clip(1, 255) # for mask reason
|
||||
image, np.ndarray
|
||||
), f"Unknown format of ori_image: {type(image)}"
|
||||
edit_image = image.clip(1, 255) # for mask reason
|
||||
edit_image = check_channels(edit_image)
|
||||
edit_image = resize_image(
|
||||
edit_image, max_length=768
|
||||
) # make w h multiple of 64, resize if w or h > max_length
|
||||
# edit_image = resize_image(
|
||||
# edit_image, max_length=768
|
||||
# ) # make w h multiple of 64, resize if w or h > max_length
|
||||
h, w = edit_image.shape[:2] # change h, w by input ref_img
|
||||
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
|
||||
if draw_pos is None:
|
||||
if masked_image is None:
|
||||
pos_imgs = np.zeros((w, h, 1))
|
||||
if isinstance(draw_pos, str):
|
||||
draw_pos = cv2.imread(draw_pos)[..., ::-1]
|
||||
assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!"
|
||||
pos_imgs = 255 - draw_pos
|
||||
elif isinstance(draw_pos, torch.Tensor):
|
||||
pos_imgs = draw_pos.cpu().numpy()
|
||||
if isinstance(masked_image, str):
|
||||
masked_image = cv2.imread(masked_image)[..., ::-1]
|
||||
assert (
|
||||
masked_image is not None
|
||||
), f"Can't read draw_pos image from{masked_image}!"
|
||||
pos_imgs = 255 - masked_image
|
||||
elif isinstance(masked_image, torch.Tensor):
|
||||
pos_imgs = masked_image.cpu().numpy()
|
||||
else:
|
||||
assert isinstance(
|
||||
draw_pos, np.ndarray
|
||||
), f"Unknown format of draw_pos: {type(draw_pos)}"
|
||||
masked_image, np.ndarray
|
||||
), f"Unknown format of draw_pos: {type(masked_image)}"
|
||||
pos_imgs = 255 - masked_image
|
||||
pos_imgs = pos_imgs[..., 0:1]
|
||||
pos_imgs = cv2.convertScaleAbs(pos_imgs)
|
||||
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
|
||||
@ -139,11 +179,8 @@ class AnyTextPipeline:
|
||||
if n_lines == 1 and texts[0] == " ":
|
||||
pass # text-to-image without text
|
||||
else:
|
||||
return (
|
||||
None,
|
||||
-1,
|
||||
f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!",
|
||||
"",
|
||||
raise RuntimeError(
|
||||
f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
|
||||
)
|
||||
elif len(pos_imgs) > n_lines:
|
||||
str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
|
||||
@ -250,12 +287,16 @@ class AnyTextPipeline:
|
||||
cond = self.model.get_learned_conditioning(
|
||||
dict(
|
||||
c_concat=[hint],
|
||||
c_crossattn=[[prompt + " , " + a_prompt] * img_count],
|
||||
c_crossattn=[[prompt] * img_count],
|
||||
text_info=info,
|
||||
)
|
||||
)
|
||||
un_cond = self.model.get_learned_conditioning(
|
||||
dict(c_concat=[hint], c_crossattn=[[n_prompt] * img_count], text_info=info)
|
||||
dict(
|
||||
c_concat=[hint],
|
||||
c_crossattn=[[negative_prompt] * img_count],
|
||||
text_info=info,
|
||||
)
|
||||
)
|
||||
shape = (4, h // 8, w // 8)
|
||||
self.model.control_scales = [strength] * 13
|
||||
@ -268,6 +309,7 @@ class AnyTextPipeline:
|
||||
eta=eta,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=un_cond,
|
||||
callback=callback
|
||||
)
|
||||
if self.use_fp16:
|
||||
samples = samples.half()
|
||||
@ -280,52 +322,18 @@ class AnyTextPipeline:
|
||||
.astype(np.uint8)
|
||||
)
|
||||
results = [x_samples[i] for i in range(img_count)]
|
||||
if (
|
||||
mode == "edit" and False
|
||||
): # replace backgound in text editing but not ideal yet
|
||||
results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
|
||||
results = [r.clip(0, 255).astype(np.uint8) for r in results]
|
||||
if len(gly_pos_imgs) > 0 and show_debug:
|
||||
glyph_bs = np.stack(gly_pos_imgs, axis=2)
|
||||
glyph_img = np.sum(glyph_bs, axis=2) * 255
|
||||
glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
|
||||
results += [np.repeat(glyph_img, 3, axis=2)]
|
||||
# debug_info
|
||||
if not show_debug:
|
||||
debug_info = ""
|
||||
else:
|
||||
input_prompt = prompt
|
||||
for t in texts:
|
||||
input_prompt = input_prompt.replace("*", f'"{t}"', 1)
|
||||
debug_info = f'<span style="color:black;font-size:18px">Prompt: </span>{input_prompt}<br> \
|
||||
<span style="color:black;font-size:18px">Size: </span>{w}x{h}<br> \
|
||||
<span style="color:black;font-size:18px">Image Count: </span>{img_count}<br> \
|
||||
<span style="color:black;font-size:18px">Seed: </span>{seed}<br> \
|
||||
<span style="color:black;font-size:18px">Use FP16: </span>{self.use_fp16}<br> \
|
||||
<span style="color:black;font-size:18px">Cost Time: </span>{(time.time()-tic):.2f}s'
|
||||
# if (
|
||||
# mode == "edit" and False
|
||||
# ): # replace backgound in text editing but not ideal yet
|
||||
# results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
|
||||
# results = [r.clip(0, 255).astype(np.uint8) for r in results]
|
||||
# if len(gly_pos_imgs) > 0 and show_debug:
|
||||
# glyph_bs = np.stack(gly_pos_imgs, axis=2)
|
||||
# glyph_img = np.sum(glyph_bs, axis=2) * 255
|
||||
# glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
|
||||
# results += [np.repeat(glyph_img, 3, axis=2)]
|
||||
rst_code = 1 if str_warning else 0
|
||||
return results, rst_code, str_warning, debug_info
|
||||
|
||||
def init_model(self):
|
||||
font_path = self.font_path
|
||||
self.font = ImageFont.truetype(font_path, size=60)
|
||||
cfg_path = self.cfg_path
|
||||
ckpt_path = os.path.join(self.model_dir, "anytext_v1.1.ckpt")
|
||||
clip_path = os.path.join(self.model_dir, "clip-vit-large-patch14")
|
||||
self.model = create_model(
|
||||
cfg_path,
|
||||
device=self.device,
|
||||
cond_stage_path=clip_path,
|
||||
use_fp16=self.use_fp16,
|
||||
)
|
||||
if self.use_fp16:
|
||||
self.model = self.model.half()
|
||||
self.model.load_state_dict(
|
||||
load_state_dict(ckpt_path, location=self.device), strict=False
|
||||
)
|
||||
self.model.eval()
|
||||
self.model = self.model.to(self.device)
|
||||
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
|
||||
return results, rst_code, str_warning
|
||||
|
||||
def modify_prompt(self, prompt):
|
||||
prompt = prompt.replace("“", '"')
|
||||
@ -360,9 +368,9 @@ class AnyTextPipeline:
|
||||
component = np.zeros_like(img)
|
||||
component[labels == label] = 255
|
||||
components.append((component, centroids[label]))
|
||||
if sort_priority == "↕":
|
||||
if sort_priority == "y":
|
||||
fir, sec = 1, 0 # top-down first
|
||||
elif sort_priority == "↔":
|
||||
elif sort_priority == "x":
|
||||
fir, sec = 0, 1 # left-right first
|
||||
components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
|
||||
sorted_components = [c[0] for c in components]
|
||||
|
@ -95,5 +95,5 @@ model:
|
||||
cond_stage_config:
|
||||
target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
|
||||
params:
|
||||
version: ./models/clip-vit-large-patch14
|
||||
version: openai/clip-vit-large-patch14
|
||||
use_vision: false # v6
|
||||
|
@ -254,7 +254,7 @@ class DDIMSampler(object):
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(i)
|
||||
callback(None, i, None, None)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
|
@ -26,11 +26,11 @@ def load_state_dict(ckpt_path, location="cpu"):
|
||||
|
||||
def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
|
||||
config = OmegaConf.load(config_path)
|
||||
if cond_stage_path:
|
||||
config.model.params.cond_stage_config.params.version = (
|
||||
cond_stage_path # use pre-downloaded ckpts, in case blocked
|
||||
)
|
||||
config.model.params.cond_stage_config.params.device = device
|
||||
# if cond_stage_path:
|
||||
# config.model.params.cond_stage_config.params.version = (
|
||||
# cond_stage_path # use pre-downloaded ckpts, in case blocked
|
||||
# )
|
||||
config.model.params.cond_stage_config.params.device = str(device)
|
||||
if use_fp16:
|
||||
config.model.params.use_fp16 = True
|
||||
config.model.params.control_stage_config.params.use_fp16 = True
|
||||
|
@ -2,7 +2,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import (
|
||||
T5Tokenizer,
|
||||
T5EncoderModel,
|
||||
CLIPTokenizer,
|
||||
CLIPTextModel,
|
||||
AutoProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from iopaint.model.anytext.ldm.util import count_params
|
||||
|
||||
@ -18,7 +25,9 @@ def _expand_mask(mask, dtype, tgt_len=None):
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _build_causal_attention_mask(bsz, seq_len, dtype):
|
||||
@ -30,6 +39,7 @@ def _build_causal_attention_mask(bsz, seq_len, dtype):
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -39,13 +49,12 @@ class AbstractEncoder(nn.Module):
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
|
||||
def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
@ -57,15 +66,17 @@ class ClassEmbedder(nn.Module):
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0. and not disable_dropout:
|
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
|
||||
if self.ucg_rate > 0.0 and not disable_dropout:
|
||||
mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
def get_unconditional_conditioning(self, bs, device="cuda"):
|
||||
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc_class = (
|
||||
self.n_classes - 1
|
||||
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc = torch.ones((bs,), device=device) * uc_class
|
||||
uc = {self.key: uc}
|
||||
return uc
|
||||
@ -79,7 +90,10 @@ def disabled_train(self, mode=True):
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
|
||||
def __init__(
|
||||
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
|
||||
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
@ -90,13 +104,20 @@ class FrozenT5Embedder(AbstractEncoder):
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
@ -109,13 +130,18 @@ class FrozenT5Embedder(AbstractEncoder):
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version="openai/clip-vit-large-patch14",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer="last",
|
||||
layer_idx=None,
|
||||
): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
@ -137,10 +163,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||
outputs = self.transformer(
|
||||
input_ids=tokens, output_hidden_states=self.layer == "hidden"
|
||||
)
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
@ -153,77 +188,24 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
# "pooled",
|
||||
"last",
|
||||
"penultimate"
|
||||
]
|
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||
freeze=True, layer="last"):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "last":
|
||||
self.layer_idx = 0
|
||||
elif self.layer == "penultimate":
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(text)
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
||||
clip_max_length=77, t5_max_length=77):
|
||||
def __init__(
|
||||
self,
|
||||
clip_version="openai/clip-vit-large-patch14",
|
||||
t5_version="google/t5-v1_1-xl",
|
||||
device="cuda",
|
||||
clip_max_length=77,
|
||||
t5_max_length=77,
|
||||
):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
||||
self.clip_encoder = FrozenCLIPEmbedder(
|
||||
clip_version, device, max_length=clip_max_length
|
||||
)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
|
||||
print(
|
||||
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
|
||||
)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
@ -236,7 +218,15 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
|
||||
class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version="openai/clip-vit-large-patch14",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
use_vision=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
@ -255,7 +245,11 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
inputs_embeds=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
seq_length = (
|
||||
input_ids.shape[-1]
|
||||
if input_ids is not None
|
||||
else inputs_embeds.shape[-2]
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
if inputs_embeds is None:
|
||||
@ -266,7 +260,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
return embeddings
|
||||
|
||||
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
|
||||
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
|
||||
self.transformer.text_model.embeddings
|
||||
)
|
||||
|
||||
def encoder_forward(
|
||||
self,
|
||||
@ -277,11 +273,19 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
hidden_states = inputs_embeds
|
||||
@ -301,7 +305,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
return hidden_states
|
||||
|
||||
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
|
||||
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
|
||||
self.transformer.text_model.encoder
|
||||
)
|
||||
|
||||
def text_encoder_forward(
|
||||
self,
|
||||
@ -313,22 +319,34 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
return_dict=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify either input_ids")
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
bsz, seq_len = input_shape
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
|
||||
hidden_states.device
|
||||
)
|
||||
causal_attention_mask = _build_causal_attention_mask(
|
||||
bsz, seq_len, hidden_states.dtype
|
||||
).to(hidden_states.device)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
@ -344,7 +362,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
return last_hidden_state
|
||||
|
||||
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
|
||||
self.transformer.text_model.forward = text_encoder_forward.__get__(
|
||||
self.transformer.text_model
|
||||
)
|
||||
|
||||
def transformer_forward(
|
||||
self,
|
||||
@ -363,7 +383,7 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
embedding_manager=embedding_manager
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
|
||||
self.transformer.forward = transformer_forward.__get__(self.transformer)
|
||||
@ -374,8 +394,15 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text, **kwargs):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
z = self.transformer(input_ids=tokens, **kwargs)
|
||||
return z
|
||||
|
@ -1,3 +1,6 @@
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from anytext_pipeline import AnyTextPipeline
|
||||
from utils import save_images
|
||||
|
||||
@ -5,48 +8,38 @@ seed = 66273235
|
||||
# seed_everything(seed)
|
||||
|
||||
pipe = AnyTextPipeline(
|
||||
cfg_path="/Users/cwq/code/github/AnyText/anytext/models_yaMl/anytext_sd15.yaml",
|
||||
model_dir="/Users/cwq/.cache/modelscope/hub/damo/cv_anytext_text_generation_editing",
|
||||
# font_path="/Users/cwq/code/github/AnyText/anytext/font/Arial_Unicode.ttf",
|
||||
# font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-VF.ttf",
|
||||
ckpt_path="/Users/cwq/code/github/IOPaint/iopaint/model/anytext/anytext_v1.1_fp16.ckpt",
|
||||
font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf",
|
||||
use_fp16=False,
|
||||
device="mps",
|
||||
)
|
||||
|
||||
img_save_folder = "SaveImages"
|
||||
params = {
|
||||
"show_debug": True,
|
||||
"image_count": 2,
|
||||
"ddim_steps": 20,
|
||||
}
|
||||
rgb_image = cv2.imread(
|
||||
"/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg"
|
||||
)[..., ::-1]
|
||||
|
||||
# # 1. text generation
|
||||
# mode = "text-generation"
|
||||
# input_data = {
|
||||
# "prompt": 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream',
|
||||
# "seed": seed,
|
||||
# "draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/gen9.png",
|
||||
# }
|
||||
# results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
|
||||
# if rtn_code >= 0:
|
||||
# save_images(results, img_save_folder)
|
||||
# print(f"Done, result images are saved in: {img_save_folder}")
|
||||
# if rtn_warning:
|
||||
# print(rtn_warning)
|
||||
#
|
||||
# exit()
|
||||
# 2. text editing
|
||||
mode = "text-editing"
|
||||
input_data = {
|
||||
"prompt": 'A cake with colorful characters that reads "EVERYDAY"',
|
||||
"seed": seed,
|
||||
"draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png",
|
||||
"ori_image": "/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg",
|
||||
}
|
||||
results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
|
||||
masked_image = cv2.imread(
|
||||
"/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png"
|
||||
)[..., ::-1]
|
||||
|
||||
rgb_image = cv2.resize(rgb_image, (512, 512))
|
||||
masked_image = cv2.resize(masked_image, (512, 512))
|
||||
|
||||
# results: list of rgb ndarray
|
||||
results, rtn_code, rtn_warning = pipe(
|
||||
prompt='A cake with colorful characters that reads "EVERYDAY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
|
||||
negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
image=rgb_image,
|
||||
masked_image=masked_image,
|
||||
num_inference_steps=20,
|
||||
strength=1.0,
|
||||
guidance_scale=9.0,
|
||||
height=rgb_image.shape[0],
|
||||
width=rgb_image.shape[1],
|
||||
seed=seed,
|
||||
sort_priority="y",
|
||||
)
|
||||
if rtn_code >= 0:
|
||||
save_images(results, img_save_folder)
|
||||
print(f"Done, result images are saved in: {img_save_folder}")
|
||||
if rtn_warning:
|
||||
print(rtn_warning)
|
||||
|
@ -9,6 +9,7 @@ from iopaint.const import (
|
||||
INSTRUCT_PIX2PIX_NAME,
|
||||
KANDINSKY22_NAME,
|
||||
POWERPAINT_NAME,
|
||||
ANYTEXT_NAME,
|
||||
)
|
||||
from iopaint.schema import ModelType
|
||||
|
||||
@ -31,6 +32,7 @@ class ModelInfo(BaseModel):
|
||||
INSTRUCT_PIX2PIX_NAME,
|
||||
KANDINSKY22_NAME,
|
||||
POWERPAINT_NAME,
|
||||
ANYTEXT_NAME,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@ -58,7 +60,7 @@ class ModelInfo(BaseModel):
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [POWERPAINT_NAME]
|
||||
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
BIN
iopaint/tests/anytext_mask.jpg
Normal file
BIN
iopaint/tests/anytext_mask.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.7 KiB |
BIN
iopaint/tests/anytext_ref.jpg
Normal file
BIN
iopaint/tests/anytext_ref.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 104 KiB |
45
iopaint/tests/test_anytext.py
Normal file
45
iopaint/tests/test_anytext.py
Normal file
@ -0,0 +1,45 @@
|
||||
import os
|
||||
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
def test_anytext(device):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="Sanster/AnyText",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt='Characters written in chalk on the blackboard that says "DADDY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
|
||||
negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=9.0,
|
||||
sd_seed=66273235,
|
||||
sd_match_histograms=True
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"anytext.png",
|
||||
img_p=current_dir / "anytext_ref.jpg",
|
||||
mask_p=current_dir / "anytext_mask.jpg",
|
||||
)
|
@ -16,7 +16,12 @@ import { Separator } from "../ui/separator"
|
||||
import { Button, ImageUploadButton } from "../ui/button"
|
||||
import { Slider } from "../ui/slider"
|
||||
import { useImage } from "@/hooks/useImage"
|
||||
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const"
|
||||
import {
|
||||
ANYTEXT,
|
||||
INSTRUCT_PIX2PIX,
|
||||
PAINT_BY_EXAMPLE,
|
||||
POWERPAINT,
|
||||
} from "@/lib/const"
|
||||
import { RowContainer, LabelTitle } from "./LabelTitle"
|
||||
import { Minus, Plus, Upload } from "lucide-react"
|
||||
import { useClickAway } from "react-use"
|
||||
@ -661,6 +666,10 @@ const DiffusionOptions = () => {
|
||||
}
|
||||
|
||||
const renderSampler = () => {
|
||||
if (settings.model.name === ANYTEXT) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<RowContainer>
|
||||
<LabelTitle text="Sampler" />
|
||||
|
@ -15,6 +15,7 @@ export const PAINT_BY_EXAMPLE = "Fantasy-Studio/Paint-by-Example"
|
||||
export const INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix"
|
||||
export const KANDINSKY_2_2 = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
export const POWERPAINT = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||
export const ANYTEXT = "Sanster/AnyText"
|
||||
|
||||
export const DEFAULT_NEGATIVE_PROMPT =
|
||||
"out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature"
|
||||
|
Loading…
Reference in New Issue
Block a user