add AnyText

Qing committed on 2024-01-21 23:25:50 +08:00
parent f5bd697687
commit 1905743886
17 changed files with 513 additions and 276 deletions

View File

@ -1 +1,71 @@
<h1 align="center">IOPaint</h1>
<p align="center">A free and open-source inpainting & outpainting tool powered by SOTA AI model.</p>
<p align="center">
<a href="https://github.com/Sanster/IOPaint">
<img alt="total download" src="https://pepy.tech/badge/iopaint" />
</a>
<a href="https://pypi.org/project/iopaint">
<img alt="version" src="https://img.shields.io/pypi/v/iopaint" />
</a>
<a href="">
<img alt="python version" src="https://img.shields.io/pypi/pyversions/iopaint" />
</a>
</p>
## Quick Start
IOPaint provides an easy-to-use web UI for running the latest AI models. Installation is just as simple, requiring only two commands:
```bash
# To use a GPU, install the CUDA build of PyTorch first.
# pip3 install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip3 install iopaint
iopaint start --model=lama --device=cpu --port=8080
```
That's it. You can now start using IOPaint by visiting http://localhost:8080 in your web browser.
You can also use IOPaint from the command line to batch-process images:
```bash
iopaint run --model=lama --device=cpu \
--input=/path/to/image_folder \
--mask=/path/to/mask_folder \
--output=output_dir
```
`--input` is the folder containing the input images, and `--mask` is the folder containing the corresponding mask images.
If `--mask` points to a single mask file instead of a folder, that mask is applied to every image.
You can see more information about the models and plugins supported by IOPaint below.
## Features
- Completely free and open-source, fully self-hosted; supports CPU, GPU, and Apple Silicon (M1/M2)
- Supports various AI models:
- Inpainting models: These models are usually used to remove people or objects from images.
- Stable Diffusion models: These models have stronger generative abilities, allowing them to add new objects to an image or to extend (outpaint) an existing one.
You can use any Stable Diffusion inpainting (or regular) model from [Huggingface](https://huggingface.co/models?other=stable-diffusion) in IOPaint.
Some commonly used models are listed below:
- [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)
- [diffusers/stable-diffusion-xl-1.0-inpainting-0.1](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1)
- [andregn/Realistic_Vision_V3.0-inpainting](https://huggingface.co/andregn/Realistic_Vision_V3.0-inpainting)
- [Lykon/dreamshaper-8-inpainting](https://huggingface.co/Lykon/dreamshaper-8-inpainting)
- [Sanster/anything-4.0-inpainting](https://huggingface.co/Sanster/anything-4.0-inpainting)
- [Sanster/PowerPaint-V1-stable-diffusion-inpainting](https://huggingface.co/Sanster/PowerPaint-V1-stable-diffusion-inpainting)
- Other Diffusion models:
- [Sanster/AnyText](https://huggingface.co/Sanster/AnyText): Generate text on images
- [timbrooks/instruct-pix2pix](https://huggingface.co/timbrooks/instruct-pix2pix)
- [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example): Fill the masked region based on a reference example image
- [kandinsky-community/kandinsky-2-2-decoder-inpaint](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint)
- [Plugins](https://iopaint.com/plugins) for post-processing:
- [Segment Anything](https://iopaint.com/plugins/interactive_seg): Accurate and fast interactive object segmentation
- [RemoveBG](https://iopaint.com/plugins/rembg): Remove image background or generate masks for foreground objects
- [Anime Segmentation](https://iopaint.com/plugins/anime_seg): Similar to RemoveBG, but the model is specifically trained for anime images.
- [RealESRGAN](https://iopaint.com/plugins/RealESRGAN): Super Resolution
- [GFPGAN](https://iopaint.com/plugins/GFPGAN): Face Restoration
- [RestoreFormer](https://iopaint.com/plugins/RestoreFormer): Face Restoration
- [FileManager](https://iopaint.com/features/file_manager): Browse your pictures conveniently and save them directly to the output directory.
- [Native macOS app](https://opticlean.io/) for the erase task
- More features at [IOPaint Docs](https://iopaint.com/)

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix" INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint" KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting" POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
ANYTEXT_NAME = "Sanster/AnyText"
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"

View File

@ -12,6 +12,7 @@ from iopaint.const import (
DIFFUSERS_SD_INPAINT_CLASS_NAME, DIFFUSERS_SD_INPAINT_CLASS_NAME,
DIFFUSERS_SDXL_CLASS_NAME, DIFFUSERS_SDXL_CLASS_NAME,
DIFFUSERS_SDXL_INPAINT_CLASS_NAME, DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
ANYTEXT_NAME,
) )
from iopaint.model_info import ModelInfo, ModelType from iopaint.model_info import ModelInfo, ModelType
@ -24,6 +25,10 @@ def cli_download_model(model: str):
logger.info(f"Downloading {model}...") logger.info(f"Downloading {model}...")
models[model].download() models[model].download()
logger.info(f"Done.") logger.info(f"Done.")
elif model == ANYTEXT_NAME:
logger.info(f"Downloading {model}...")
models[model].download()
logger.info(f"Done.")
else: else:
logger.info(f"Downloading model from Huggingface: {model}") logger.info(f"Downloading model from Huggingface: {model}")
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
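
With the new branch above, pre-downloading AnyText reduces to the model class's own `download()` helper, which fetches the pipeline config, the fp16 checkpoint, and the bundled font from the Hugging Face repo. A minimal sketch, assuming the `models` registry is importable from `iopaint.model` as defined in this commit's `__init__.py`:

```python
from iopaint.const import ANYTEXT_NAME
from iopaint.model import models  # assumed location of the model registry

# Equivalent to the new elif branch: AnyText.download() returns the paths of the
# checkpoint and font it cached from the Hugging Face Hub.
ckpt_path, font_path = models[ANYTEXT_NAME].download()
print(ckpt_path, font_path)
```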
@ -210,6 +215,7 @@ def scan_models() -> List[ModelInfo]:
"StableDiffusionInstructPix2PixPipeline", "StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline", "PaintByExamplePipeline",
"KandinskyV22InpaintPipeline", "KandinskyV22InpaintPipeline",
"AnyText",
]: ]:
model_type = ModelType.DIFFUSERS_OTHER model_type = ModelType.DIFFUSERS_OTHER
else: else:

View File

@ -1,3 +1,4 @@
from .anytext.anytext_model import AnyText
from .controlnet import ControlNet from .controlnet import ControlNet
from .fcf import FcF from .fcf import FcF
from .instruct_pix2pix import InstructPix2Pix from .instruct_pix2pix import InstructPix2Pix
@ -32,4 +33,5 @@ models = {
Kandinsky22.name: Kandinsky22, Kandinsky22.name: Kandinsky22,
SDXL.name: SDXL, SDXL.name: SDXL,
PowerPaint.name: PowerPaint, PowerPaint.name: PowerPaint,
AnyText.name: AnyText,
} }

View File

@ -0,0 +1,73 @@
import torch
from huggingface_hub import hf_hub_download

from iopaint.const import ANYTEXT_NAME
from iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
from iopaint.model.base import DiffusionInpaintModel
from iopaint.model.utils import get_torch_dtype, is_local_files_only
from iopaint.schema import InpaintRequest


class AnyText(DiffusionInpaintModel):
    name = ANYTEXT_NAME
    pad_mod = 64
    is_erase_model = False

    @staticmethod
    def download(local_files_only=False):
        hf_hub_download(
            repo_id=ANYTEXT_NAME,
            filename="model_index.json",
            local_files_only=local_files_only,
        )
        ckpt_path = hf_hub_download(
            repo_id=ANYTEXT_NAME,
            filename="pytorch_model.fp16.safetensors",
            local_files_only=local_files_only,
        )
        font_path = hf_hub_download(
            repo_id=ANYTEXT_NAME,
            filename="SourceHanSansSC-Medium.otf",
            local_files_only=local_files_only,
        )
        return ckpt_path, font_path

    def init_model(self, device, **kwargs):
        local_files_only = is_local_files_only(**kwargs)
        ckpt_path, font_path = self.download(local_files_only)
        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        self.model = AnyTextPipeline(
            ckpt_path=ckpt_path,
            font_path=font_path,
            device=device,
            use_fp16=torch_dtype == torch.float16,
        )
        self.callback = kwargs.pop("callback", None)

    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have the same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means the area to inpaint
        return: BGR IMAGE
        """
        height, width = image.shape[:2]
        mask = mask.astype("float32") / 255.0
        masked_image = image * (1 - mask)

        # the pipeline returns a list of RGB ndarrays
        results, rtn_code, rtn_warning = self.model(
            image=image,
            masked_image=masked_image,
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            num_inference_steps=config.sd_steps,
            strength=config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            height=height,
            width=width,
            seed=config.sd_seed,
            sort_priority="y",
            callback=self.callback,
        )
        # flip channels so the wrapper returns BGR, as the docstring promises
        inpainted_rgb_image = results[0][..., ::-1]
        return inpainted_rgb_image
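
A minimal usage sketch for this wrapper, assuming the usual IOPaint flow in which a `ModelManager` instance is called with `(image, mask, config)` (the test added later in this commit does the same through its helpers); the file paths and prompt here are placeholders:

```python
import cv2
import torch

from iopaint.model_manager import ModelManager
from iopaint.schema import InpaintRequest

# Placeholder inputs: an RGB image and a single-channel mask where 255 marks
# the region in which text should be generated.
image = cv2.imread("ref.jpg")[..., ::-1]
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)[..., None]

model = ModelManager(
    name="Sanster/AnyText",
    device=torch.device("cuda"),
    disable_nsfw=True,
    sd_cpu_textencoder=False,
)
config = InpaintRequest(
    prompt='A street sign that reads "HELLO", best quality, clear text edges',
    negative_prompt="low-res, worst quality, watermark, unreadable text",
    sd_steps=20,
    sd_guidance_scale=9.0,
    sd_seed=66273235,
)
result_bgr = model(image, mask, config)  # forward() hands back a BGR ndarray
cv2.imwrite("anytext_result.png", result_bgr)
```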

View File

@ -5,20 +5,21 @@ Code: https://github.com/tyxsspa/AnyText
Copyright (c) Alibaba, Inc. and its affiliates. Copyright (c) Alibaba, Inc. and its affiliates.
""" """
import os import os
from pathlib import Path
from iopaint.model.utils import set_seed
from safetensors.torch import load_file
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import torch import torch
import random
import re import re
import numpy as np import numpy as np
import cv2 import cv2
import einops import einops
import time
from PIL import ImageFont from PIL import ImageFont
from iopaint.model.anytext.cldm.model import create_model, load_state_dict from iopaint.model.anytext.cldm.model import create_model, load_state_dict
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
from iopaint.model.anytext.utils import ( from iopaint.model.anytext.utils import (
resize_image,
check_channels, check_channels,
draw_glyph, draw_glyph,
draw_glyph2, draw_glyph2,
@ -29,55 +30,93 @@ BBOX_MAX_NUM = 8
PLACE_HOLDER = "*" PLACE_HOLDER = "*"
max_chars = 20 max_chars = 20
ANYTEXT_CFG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
)
def check_limits(tensor):
float16_min = torch.finfo(torch.float16).min
float16_max = torch.finfo(torch.float16).max
# check whether any value in the tensor is below float16's minimum or above its maximum
is_below_min = (tensor < float16_min).any()
is_above_max = (tensor > float16_max).any()
return is_below_min or is_above_max
class AnyTextPipeline: class AnyTextPipeline:
def __init__(self, cfg_path, model_dir, font_path, device, use_fp16=True): def __init__(self, ckpt_path, font_path, device, use_fp16=True):
self.cfg_path = cfg_path self.cfg_path = ANYTEXT_CFG
self.model_dir = model_dir
self.font_path = font_path self.font_path = font_path
self.use_fp16 = use_fp16 self.use_fp16 = use_fp16
self.device = device self.device = device
self.init_model()
self.font = ImageFont.truetype(font_path, size=60)
self.model = create_model(
self.cfg_path,
device=self.device,
use_fp16=self.use_fp16,
)
if self.use_fp16:
self.model = self.model.half()
if Path(ckpt_path).suffix == ".safetensors":
state_dict = load_file(ckpt_path, device="cpu")
else:
state_dict = load_state_dict(ckpt_path, location="cpu")
self.model.load_state_dict(state_dict, strict=False)
self.model = self.model.eval().to(self.device)
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
def __call__(
self,
prompt: str,
negative_prompt: str,
image: np.ndarray,
masked_image: np.ndarray,
num_inference_steps: int,
strength: float,
guidance_scale: float,
height: int,
width: int,
seed: int,
sort_priority: str = "y",
callback=None,
):
""" """
return:
Args:
prompt:
negative_prompt:
image:
masked_image:
num_inference_steps:
strength:
guidance_scale:
height:
width:
seed:
sort_priority: x: left-right, y: top-down
Returns:
result: list of images in numpy.ndarray format result: list of images in numpy.ndarray format
rst_code: 0: normal -1: error 1:warning rst_code: 0: normal -1: error 1:warning
rst_info: string of error or warning rst_info: string of error or warning
debug_info: string for debug, only valid if show_debug=True
""" """
set_seed(seed)
def __call__(self, input_tensor, **forward_params):
tic = time.time()
str_warning = "" str_warning = ""
# get inputs
seed = input_tensor.get("seed", -1)
if seed == -1:
seed = random.randint(0, 99999999)
# seed_everything(seed)
prompt = input_tensor.get("prompt")
draw_pos = input_tensor.get("draw_pos")
ori_image = input_tensor.get("ori_image")
mode = forward_params.get("mode") mode = "text-editing"
sort_priority = forward_params.get("sort_priority", "") revise_pos = False
show_debug = forward_params.get("show_debug", False) img_count = 1
revise_pos = forward_params.get("revise_pos", False) ddim_steps = num_inference_steps
img_count = forward_params.get("image_count", 4) w = width
ddim_steps = forward_params.get("ddim_steps", 20) h = height
w = forward_params.get("image_width", 512) strength = strength
h = forward_params.get("image_height", 512) cfg_scale = guidance_scale
strength = forward_params.get("strength", 1.0) eta = 0.0
cfg_scale = forward_params.get("cfg_scale", 9.0)
eta = forward_params.get("eta", 0.0)
a_prompt = forward_params.get(
"a_prompt",
"best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks",
)
n_prompt = forward_params.get(
"n_prompt",
"low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
)
prompt, texts = self.modify_prompt(prompt) prompt, texts = self.modify_prompt(prompt)
if prompt is None and texts is None: if prompt is None and texts is None:
@ -91,43 +130,44 @@ class AnyTextPipeline:
if mode in ["text-generation", "gen"]: if mode in ["text-generation", "gen"]:
edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
elif mode in ["text-editing", "edit"]: elif mode in ["text-editing", "edit"]:
if draw_pos is None or ori_image is None: if masked_image is None or image is None:
return ( return (
None, None,
-1, -1,
"Reference image and position image are needed for text editing!", "Reference image and position image are needed for text editing!",
"", "",
) )
if isinstance(ori_image, str): if isinstance(image, str):
ori_image = cv2.imread(ori_image)[..., ::-1] image = cv2.imread(image)[..., ::-1]
assert ( assert image is not None, f"Can't read ori_image image from{image}!"
ori_image is not None elif isinstance(image, torch.Tensor):
), f"Can't read ori_image image from{ori_image}!" image = image.cpu().numpy()
elif isinstance(ori_image, torch.Tensor):
ori_image = ori_image.cpu().numpy()
else: else:
assert isinstance( assert isinstance(
ori_image, np.ndarray image, np.ndarray
), f"Unknown format of ori_image: {type(ori_image)}" ), f"Unknown format of ori_image: {type(image)}"
edit_image = ori_image.clip(1, 255) # for mask reason edit_image = image.clip(1, 255) # for mask reason
edit_image = check_channels(edit_image) edit_image = check_channels(edit_image)
edit_image = resize_image( # edit_image = resize_image(
edit_image, max_length=768 # edit_image, max_length=768
) # make w h multiple of 64, resize if w or h > max_length # ) # make w h multiple of 64, resize if w or h > max_length
h, w = edit_image.shape[:2] # change h, w by input ref_img h, w = edit_image.shape[:2] # change h, w by input ref_img
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg) # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
if draw_pos is None: if masked_image is None:
pos_imgs = np.zeros((w, h, 1)) pos_imgs = np.zeros((w, h, 1))
if isinstance(draw_pos, str): if isinstance(masked_image, str):
draw_pos = cv2.imread(draw_pos)[..., ::-1] masked_image = cv2.imread(masked_image)[..., ::-1]
assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!" assert (
pos_imgs = 255 - draw_pos masked_image is not None
elif isinstance(draw_pos, torch.Tensor): ), f"Can't read draw_pos image from{masked_image}!"
pos_imgs = draw_pos.cpu().numpy() pos_imgs = 255 - masked_image
elif isinstance(masked_image, torch.Tensor):
pos_imgs = masked_image.cpu().numpy()
else: else:
assert isinstance( assert isinstance(
draw_pos, np.ndarray masked_image, np.ndarray
), f"Unknown format of draw_pos: {type(draw_pos)}" ), f"Unknown format of draw_pos: {type(masked_image)}"
pos_imgs = 255 - masked_image
pos_imgs = pos_imgs[..., 0:1] pos_imgs = pos_imgs[..., 0:1]
pos_imgs = cv2.convertScaleAbs(pos_imgs) pos_imgs = cv2.convertScaleAbs(pos_imgs)
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
@ -139,11 +179,8 @@ class AnyTextPipeline:
if n_lines == 1 and texts[0] == " ": if n_lines == 1 and texts[0] == " ":
pass # text-to-image without text pass # text-to-image without text
else: else:
return ( raise RuntimeError(
None, f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
-1,
f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!",
"",
) )
elif len(pos_imgs) > n_lines: elif len(pos_imgs) > n_lines:
str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
@ -250,12 +287,16 @@ class AnyTextPipeline:
cond = self.model.get_learned_conditioning( cond = self.model.get_learned_conditioning(
dict( dict(
c_concat=[hint], c_concat=[hint],
c_crossattn=[[prompt + " , " + a_prompt] * img_count], c_crossattn=[[prompt] * img_count],
text_info=info, text_info=info,
) )
) )
un_cond = self.model.get_learned_conditioning( un_cond = self.model.get_learned_conditioning(
dict(c_concat=[hint], c_crossattn=[[n_prompt] * img_count], text_info=info) dict(
c_concat=[hint],
c_crossattn=[[negative_prompt] * img_count],
text_info=info,
)
) )
shape = (4, h // 8, w // 8) shape = (4, h // 8, w // 8)
self.model.control_scales = [strength] * 13 self.model.control_scales = [strength] * 13
@ -268,6 +309,7 @@ class AnyTextPipeline:
eta=eta, eta=eta,
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=un_cond, unconditional_conditioning=un_cond,
callback=callback
) )
if self.use_fp16: if self.use_fp16:
samples = samples.half() samples = samples.half()
@ -280,52 +322,18 @@ class AnyTextPipeline:
.astype(np.uint8) .astype(np.uint8)
) )
results = [x_samples[i] for i in range(img_count)] results = [x_samples[i] for i in range(img_count)]
if ( # if (
mode == "edit" and False # mode == "edit" and False
): # replace backgound in text editing but not ideal yet # ): # replace backgound in text editing but not ideal yet
results = [r * np_hint + edit_image * (1 - np_hint) for r in results] # results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
results = [r.clip(0, 255).astype(np.uint8) for r in results] # results = [r.clip(0, 255).astype(np.uint8) for r in results]
if len(gly_pos_imgs) > 0 and show_debug: # if len(gly_pos_imgs) > 0 and show_debug:
glyph_bs = np.stack(gly_pos_imgs, axis=2) # glyph_bs = np.stack(gly_pos_imgs, axis=2)
glyph_img = np.sum(glyph_bs, axis=2) * 255 # glyph_img = np.sum(glyph_bs, axis=2) * 255
glyph_img = glyph_img.clip(0, 255).astype(np.uint8) # glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
results += [np.repeat(glyph_img, 3, axis=2)] # results += [np.repeat(glyph_img, 3, axis=2)]
# debug_info
if not show_debug:
debug_info = ""
else:
input_prompt = prompt
for t in texts:
input_prompt = input_prompt.replace("*", f'"{t}"', 1)
debug_info = f'<span style="color:black;font-size:18px">Prompt: </span>{input_prompt}<br> \
<span style="color:black;font-size:18px">Size: </span>{w}x{h}<br> \
<span style="color:black;font-size:18px">Image Count: </span>{img_count}<br> \
<span style="color:black;font-size:18px">Seed: </span>{seed}<br> \
<span style="color:black;font-size:18px">Use FP16: </span>{self.use_fp16}<br> \
<span style="color:black;font-size:18px">Cost Time: </span>{(time.time()-tic):.2f}s'
rst_code = 1 if str_warning else 0 rst_code = 1 if str_warning else 0
return results, rst_code, str_warning, debug_info return results, rst_code, str_warning
def init_model(self):
font_path = self.font_path
self.font = ImageFont.truetype(font_path, size=60)
cfg_path = self.cfg_path
ckpt_path = os.path.join(self.model_dir, "anytext_v1.1.ckpt")
clip_path = os.path.join(self.model_dir, "clip-vit-large-patch14")
self.model = create_model(
cfg_path,
device=self.device,
cond_stage_path=clip_path,
use_fp16=self.use_fp16,
)
if self.use_fp16:
self.model = self.model.half()
self.model.load_state_dict(
load_state_dict(ckpt_path, location=self.device), strict=False
)
self.model.eval()
self.model = self.model.to(self.device)
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
def modify_prompt(self, prompt): def modify_prompt(self, prompt):
prompt = prompt.replace("“", '"')
@ -360,9 +368,9 @@ class AnyTextPipeline:
component = np.zeros_like(img) component = np.zeros_like(img)
component[labels == label] = 255 component[labels == label] = 255
components.append((component, centroids[label])) components.append((component, centroids[label]))
if sort_priority == "↕": if sort_priority == "y":
fir, sec = 1, 0 # top-down first fir, sec = 1, 0 # top-down first
elif sort_priority == "↔": elif sort_priority == "x":
fir, sec = 0, 1 # left-right first fir, sec = 0, 1 # left-right first
components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
sorted_components = [c[0] for c in components] sorted_components = [c[0] for c in components]

View File

@ -95,5 +95,5 @@ model:
cond_stage_config: cond_stage_config:
target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3 target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
params: params:
version: ./models/clip-vit-large-patch14 version: openai/clip-vit-large-patch14
use_vision: false # v6 use_vision: false # v6

View File

@ -254,7 +254,7 @@ class DDIMSampler(object):
) )
img, pred_x0 = outs img, pred_x0 = outs
if callback: if callback:
callback(i) callback(None, i, None, None)
if img_callback: if img_callback:
img_callback(pred_x0, i) img_callback(pred_x0, i)
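
The sampler now invokes the progress callback with four positional arguments instead of a bare step index, so any callback handed down from `AnyTextPipeline` must accept the wider signature. A hedged example that only uses the step index (the other slots arrive as `None` from this call site):

```python
# Assumed signature: (pipeline_or_none, step_index, timestep_or_none, latents_or_none)
def on_step(_pipe, step, _timestep, _latents):
    print(f"DDIM step {step}")
```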

View File

@ -26,11 +26,11 @@ def load_state_dict(ckpt_path, location="cpu"):
def create_model(config_path, device, cond_stage_path=None, use_fp16=False): def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
config = OmegaConf.load(config_path) config = OmegaConf.load(config_path)
if cond_stage_path: # if cond_stage_path:
config.model.params.cond_stage_config.params.version = ( # config.model.params.cond_stage_config.params.version = (
cond_stage_path # use pre-downloaded ckpts, in case blocked # cond_stage_path # use pre-downloaded ckpts, in case blocked
) # )
config.model.params.cond_stage_config.params.device = device config.model.params.cond_stage_config.params.device = str(device)
if use_fp16: if use_fp16:
config.model.params.use_fp16 = True config.model.params.use_fp16 = True
config.model.params.control_stage_config.params.use_fp16 = True config.model.params.control_stage_config.params.use_fp16 = True

View File

@ -2,7 +2,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPVisionModelWithProjection from transformers import (
T5Tokenizer,
T5EncoderModel,
CLIPTokenizer,
CLIPTextModel,
AutoProcessor,
CLIPVisionModelWithProjection,
)
from iopaint.model.anytext.ldm.util import count_params from iopaint.model.anytext.ldm.util import count_params
@ -18,7 +25,9 @@ def _expand_mask(mask, dtype, tgt_len=None):
inverted_mask = 1.0 - expanded_mask inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _build_causal_attention_mask(bsz, seq_len, dtype): def _build_causal_attention_mask(bsz, seq_len, dtype):
@ -30,6 +39,7 @@ def _build_causal_attention_mask(bsz, seq_len, dtype):
mask = mask.unsqueeze(1) # expand mask mask = mask.unsqueeze(1) # expand mask
return mask return mask
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -39,13 +49,12 @@ class AbstractEncoder(nn.Module):
class IdentityEncoder(AbstractEncoder): class IdentityEncoder(AbstractEncoder):
def encode(self, x): def encode(self, x):
return x return x
class ClassEmbedder(nn.Module): class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
super().__init__() super().__init__()
self.key = key self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim) self.embedding = nn.Embedding(n_classes, embed_dim)
@ -57,15 +66,17 @@ class ClassEmbedder(nn.Module):
key = self.key key = self.key
# this is for use in crossattn # this is for use in crossattn
c = batch[key][:, None] c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout: if self.ucg_rate > 0.0 and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
c = c.long() c = c.long()
c = self.embedding(c) c = self.embedding(c)
return c return c
def get_unconditional_conditioning(self, bs, device="cuda"): def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) uc_class = (
self.n_classes - 1
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs,), device=device) * uc_class uc = torch.ones((bs,), device=device) * uc_class
uc = {self.key: uc} uc = {self.key: uc}
return uc return uc
@ -79,7 +90,10 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder): class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text""" """Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
def __init__(
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__() super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version) self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version)
@ -90,13 +104,20 @@ class FrozenT5Embedder(AbstractEncoder):
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
#self.train = disabled_train # self.train = disabled_train
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
def forward(self, text): def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device) tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens) outputs = self.transformer(input_ids=tokens)
@ -109,13 +130,18 @@ class FrozenT5Embedder(AbstractEncoder):
class FrozenCLIPEmbedder(AbstractEncoder): class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last", LAYERS = ["last", "pooled", "hidden"]
"pooled",
"hidden" def __init__(
] self,
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, version="openai/clip-vit-large-patch14",
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 device="cuda",
max_length=77,
freeze=True,
layer="last",
layer_idx=None,
): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version) self.tokenizer = CLIPTokenizer.from_pretrained(version)
@ -137,10 +163,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False param.requires_grad = False
def forward(self, text): def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device) tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") outputs = self.transformer(
input_ids=tokens, output_hidden_states=self.layer == "hidden"
)
if self.layer == "last": if self.layer == "last":
z = outputs.last_hidden_state z = outputs.last_hidden_state
elif self.layer == "pooled": elif self.layer == "pooled":
@ -153,77 +188,24 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return self(text) return self(text)
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
# "pooled",
"last",
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="last"):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder): class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", def __init__(
clip_max_length=77, t5_max_length=77): self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
clip_max_length=77,
t5_max_length=77,
):
super().__init__() super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length
)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " print(
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
)
def encode(self, text): def encode(self, text):
return self(text) return self(text)
@ -236,7 +218,15 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
class FrozenCLIPEmbedderT3(AbstractEncoder): class FrozenCLIPEmbedderT3(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)""" """Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False):
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
use_vision=False,
):
super().__init__() super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version) self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version)
@ -255,7 +245,11 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
inputs_embeds=None, inputs_embeds=None,
embedding_manager=None, embedding_manager=None,
): ):
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] seq_length = (
input_ids.shape[-1]
if input_ids is not None
else inputs_embeds.shape[-2]
)
if position_ids is None: if position_ids is None:
position_ids = self.position_ids[:, :seq_length] position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None: if inputs_embeds is None:
@ -266,7 +260,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
embeddings = inputs_embeds + position_embeddings embeddings = inputs_embeds + position_embeddings
return embeddings return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings) self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
self.transformer.text_model.embeddings
)
def encoder_forward( def encoder_forward(
self, self,
@ -277,11 +273,19 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_hidden_states = ( output_attentions
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -301,7 +305,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
return hidden_states return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) self.transformer.text_model.encoder.forward = encoder_forward.__get__(
self.transformer.text_model.encoder
)
def text_encoder_forward( def text_encoder_forward(
self, self,
@ -313,22 +319,34 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
return_dict=None, return_dict=None,
embedding_manager=None, embedding_manager=None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_hidden_states = ( output_attentions
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None: if input_ids is None:
raise ValueError("You have to specify either input_ids") raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager) hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding_manager=embedding_manager,
)
bsz, seq_len = input_shape bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here. # CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( causal_attention_mask = _build_causal_attention_mask(
hidden_states.device bsz, seq_len, hidden_states.dtype
) ).to(hidden_states.device)
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@ -344,7 +362,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
last_hidden_state = self.final_layer_norm(last_hidden_state) last_hidden_state = self.final_layer_norm(last_hidden_state)
return last_hidden_state return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) self.transformer.text_model.forward = text_encoder_forward.__get__(
self.transformer.text_model
)
def transformer_forward( def transformer_forward(
self, self,
@ -363,7 +383,7 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
embedding_manager=embedding_manager embedding_manager=embedding_manager,
) )
self.transformer.forward = transformer_forward.__get__(self.transformer) self.transformer.forward = transformer_forward.__get__(self.transformer)
@ -374,8 +394,15 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
param.requires_grad = False param.requires_grad = False
def forward(self, text, **kwargs): def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device) tokens = batch_encoding["input_ids"].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs) z = self.transformer(input_ids=tokens, **kwargs)
return z return z
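
For reference, a hedged sketch of instantiating the text encoder that `anytext_sd15.yaml` now points directly at the Hugging Face checkpoint (constructor arguments mirror the defaults shown above; running it downloads `openai/clip-vit-large-patch14`):

```python
from iopaint.model.anytext.ldm.modules.encoders.modules import FrozenCLIPEmbedderT3

# use_vision=False matches the "use_vision: false" setting in the YAML config.
encoder = FrozenCLIPEmbedderT3(
    version="openai/clip-vit-large-patch14",
    device="cpu",
    use_vision=False,
)
z = encoder(['a storefront sign that reads "OPEN"'])
print(z.shape)  # expected: (1, 77, 768) token embeddings
```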

View File

@ -1,3 +1,6 @@
import cv2
import os
from anytext_pipeline import AnyTextPipeline from anytext_pipeline import AnyTextPipeline
from utils import save_images from utils import save_images
@ -5,48 +8,38 @@ seed = 66273235
# seed_everything(seed) # seed_everything(seed)
pipe = AnyTextPipeline( pipe = AnyTextPipeline(
cfg_path="/Users/cwq/code/github/AnyText/anytext/models_yaMl/anytext_sd15.yaml", ckpt_path="/Users/cwq/code/github/IOPaint/iopaint/model/anytext/anytext_v1.1_fp16.ckpt",
model_dir="/Users/cwq/.cache/modelscope/hub/damo/cv_anytext_text_generation_editing",
# font_path="/Users/cwq/code/github/AnyText/anytext/font/Arial_Unicode.ttf",
# font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-VF.ttf",
font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf", font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf",
use_fp16=False, use_fp16=False,
device="mps", device="mps",
) )
img_save_folder = "SaveImages" img_save_folder = "SaveImages"
params = { rgb_image = cv2.imread(
"show_debug": True, "/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg"
"image_count": 2, )[..., ::-1]
"ddim_steps": 20,
}
# # 1. text generation masked_image = cv2.imread(
# mode = "text-generation" "/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png"
# input_data = { )[..., ::-1]
# "prompt": 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream',
# "seed": seed, rgb_image = cv2.resize(rgb_image, (512, 512))
# "draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/gen9.png", masked_image = cv2.resize(masked_image, (512, 512))
# }
# results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params) # results: list of rgb ndarray
# if rtn_code >= 0: results, rtn_code, rtn_warning = pipe(
# save_images(results, img_save_folder) prompt='A cake with colorful characters that reads "EVERYDAY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
# print(f"Done, result images are saved in: {img_save_folder}") negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
# if rtn_warning: image=rgb_image,
# print(rtn_warning) masked_image=masked_image,
# num_inference_steps=20,
# exit() strength=1.0,
# 2. text editing guidance_scale=9.0,
mode = "text-editing" height=rgb_image.shape[0],
input_data = { width=rgb_image.shape[1],
"prompt": 'A cake with colorful characters that reads "EVERYDAY"', seed=seed,
"seed": seed, sort_priority="y",
"draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png", )
"ori_image": "/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg",
}
results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
if rtn_code >= 0: if rtn_code >= 0:
save_images(results, img_save_folder) save_images(results, img_save_folder)
print(f"Done, result images are saved in: {img_save_folder}") print(f"Done, result images are saved in: {img_save_folder}")
if rtn_warning:
print(rtn_warning)

View File

@ -9,6 +9,7 @@ from iopaint.const import (
INSTRUCT_PIX2PIX_NAME, INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME, KANDINSKY22_NAME,
POWERPAINT_NAME, POWERPAINT_NAME,
ANYTEXT_NAME,
) )
from iopaint.schema import ModelType from iopaint.schema import ModelType
@ -31,6 +32,7 @@ class ModelInfo(BaseModel):
INSTRUCT_PIX2PIX_NAME, INSTRUCT_PIX2PIX_NAME,
KANDINSKY22_NAME, KANDINSKY22_NAME,
POWERPAINT_NAME, POWERPAINT_NAME,
ANYTEXT_NAME,
] ]
@computed_field @computed_field
@ -58,7 +60,7 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL, ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT, ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [POWERPAINT_NAME] ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
@computed_field @computed_field
@property @property

Binary file not shown (new image, 6.7 KiB)

Binary file not shown (new image, 104 KiB)

View File

@ -0,0 +1,45 @@
import os

from iopaint.tests.utils import check_device, get_config, assert_equal

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from pathlib import Path

import pytest
import torch

from iopaint.model_manager import ModelManager
from iopaint.schema import HDStrategy

current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)


@pytest.mark.parametrize("device", ["cuda", "mps"])
def test_anytext(device):
    sd_steps = check_device(device)
    model = ModelManager(
        name="Sanster/AnyText",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )
    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt='Characters written in chalk on the blackboard that says "DADDY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
        negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
        sd_steps=sd_steps,
        sd_guidance_scale=9.0,
        sd_seed=66273235,
        sd_match_histograms=True,
    )
    assert_equal(
        model,
        cfg,
        f"anytext.png",
        img_p=current_dir / "anytext_ref.jpg",
        mask_p=current_dir / "anytext_mask.jpg",
    )
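
To run only this test locally (the file path below is an assumption based on the `iopaint.tests.utils` import), pytest can be invoked programmatically:

```python
import pytest

# Hypothetical path; adjust if the test file lives elsewhere in the repo.
pytest.main(["iopaint/tests/test_anytext.py", "-k", "cuda", "-x"])
```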

View File

@ -16,7 +16,12 @@ import { Separator } from "../ui/separator"
import { Button, ImageUploadButton } from "../ui/button" import { Button, ImageUploadButton } from "../ui/button"
import { Slider } from "../ui/slider" import { Slider } from "../ui/slider"
import { useImage } from "@/hooks/useImage" import { useImage } from "@/hooks/useImage"
import { INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, POWERPAINT } from "@/lib/const" import {
ANYTEXT,
INSTRUCT_PIX2PIX,
PAINT_BY_EXAMPLE,
POWERPAINT,
} from "@/lib/const"
import { RowContainer, LabelTitle } from "./LabelTitle" import { RowContainer, LabelTitle } from "./LabelTitle"
import { Minus, Plus, Upload } from "lucide-react" import { Minus, Plus, Upload } from "lucide-react"
import { useClickAway } from "react-use" import { useClickAway } from "react-use"
@ -661,6 +666,10 @@ const DiffusionOptions = () => {
} }
const renderSampler = () => { const renderSampler = () => {
if (settings.model.name === ANYTEXT) {
return null
}
return ( return (
<RowContainer> <RowContainer>
<LabelTitle text="Sampler" /> <LabelTitle text="Sampler" />

View File

@ -15,6 +15,7 @@ export const PAINT_BY_EXAMPLE = "Fantasy-Studio/Paint-by-Example"
export const INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix" export const INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix"
export const KANDINSKY_2_2 = "kandinsky-community/kandinsky-2-2-decoder-inpaint" export const KANDINSKY_2_2 = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
export const POWERPAINT = "Sanster/PowerPaint-V1-stable-diffusion-inpainting" export const POWERPAINT = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
export const ANYTEXT = "Sanster/AnyText"
export const DEFAULT_NEGATIVE_PROMPT = export const DEFAULT_NEGATIVE_PROMPT =
"out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature" "out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature"