	new file:   inpaint/__init__.py
	new file:   inpaint/__main__.py
	new file:   inpaint/api.py
	new file:   inpaint/batch_processing.py
	new file:   inpaint/benchmark.py
	new file:   inpaint/cli.py
	new file:   inpaint/const.py
	new file:   inpaint/download.py
	new file:   inpaint/file_manager/__init__.py
	new file:   inpaint/file_manager/file_manager.py
	new file:   inpaint/file_manager/storage_backends.py
	new file:   inpaint/file_manager/utils.py
	new file:   inpaint/helper.py
	new file:   inpaint/installer.py
	new file:   inpaint/model/__init__.py
	new file:   inpaint/model/anytext/__init__.py
	new file:   inpaint/model/anytext/anytext_model.py
	new file:   inpaint/model/anytext/anytext_pipeline.py
	new file:   inpaint/model/anytext/anytext_sd15.yaml
	new file:   inpaint/model/anytext/cldm/__init__.py
	new file:   inpaint/model/anytext/cldm/cldm.py
	new file:   inpaint/model/anytext/cldm/ddim_hacked.py
	new file:   inpaint/model/anytext/cldm/embedding_manager.py
	new file:   inpaint/model/anytext/cldm/hack.py
	new file:   inpaint/model/anytext/cldm/model.py
	new file:   inpaint/model/anytext/cldm/recognizer.py
	new file:   inpaint/model/anytext/ldm/__init__.py
	new file:   inpaint/model/anytext/ldm/models/__init__.py
	new file:   inpaint/model/anytext/ldm/models/autoencoder.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/__init__.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/ddim.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/ddpm.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/plms.py
	new file:   inpaint/model/anytext/ldm/models/diffusion/sampling_util.py
	new file:   inpaint/model/anytext/ldm/modules/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/attention.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/model.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
	new file:   inpaint/model/anytext/ldm/modules/diffusionmodules/util.py
	new file:   inpaint/model/anytext/ldm/modules/distributions/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/distributions/distributions.py
	new file:   inpaint/model/anytext/ldm/modules/ema.py
	new file:   inpaint/model/anytext/ldm/modules/encoders/__init__.py
	new file:   inpaint/model/anytext/ldm/modules/encoders/modules.py
	new file:   inpaint/model/anytext/ldm/util.py
	new file:   inpaint/model/anytext/main.py
	new file:   inpaint/model/anytext/ocr_recog/RNN.py
	new file:   inpaint/model/anytext/ocr_recog/RecCTCHead.py
	new file:   inpaint/model/anytext/ocr_recog/RecModel.py
	new file:   inpaint/model/anytext/ocr_recog/RecMv1_enhance.py
	new file:   inpaint/model/anytext/ocr_recog/RecSVTR.py
	new file:   inpaint/model/anytext/ocr_recog/__init__.py
	new file:   inpaint/model/anytext/ocr_recog/common.py
	new file:   inpaint/model/anytext/ocr_recog/en_dict.txt
	new file:   inpaint/model/anytext/ocr_recog/ppocr_keys_v1.txt
	new file:   inpaint/model/anytext/utils.py
	new file:   inpaint/model/base.py
	new file:   inpaint/model/brushnet/__init__.py
	new file:   inpaint/model/brushnet/brushnet.py
	new file:   inpaint/model/brushnet/brushnet_unet_forward.py
	new file:   inpaint/model/brushnet/brushnet_wrapper.py
	new file:   inpaint/model/brushnet/pipeline_brushnet.py
	new file:   inpaint/model/brushnet/unet_2d_blocks.py
	new file:   inpaint/model/controlnet.py
	new file:   inpaint/model/ddim_sampler.py
	new file:   inpaint/model/fcf.py
	new file:   inpaint/model/helper/__init__.py
	new file:   inpaint/model/helper/controlnet_preprocess.py
	new file:   inpaint/model/helper/cpu_text_encoder.py
	new file:   inpaint/model/helper/g_diffuser_bot.py
	new file:   inpaint/model/instruct_pix2pix.py
	new file:   inpaint/model/kandinsky.py
	new file:   inpaint/model/lama.py
	new file:   inpaint/model/ldm.py
	new file:   inpaint/model/manga.py
	new file:   inpaint/model/mat.py
	new file:   inpaint/model/mi_gan.py
	new file:   inpaint/model/opencv2.py
	new file:   inpaint/model/original_sd_configs/__init__.py
	new file:   inpaint/model/original_sd_configs/sd_xl_base.yaml
	new file:   inpaint/model/original_sd_configs/sd_xl_refiner.yaml
	new file:   inpaint/model/original_sd_configs/v1-inference.yaml
	new file:   inpaint/model/original_sd_configs/v2-inference-v.yaml
	new file:   inpaint/model/paint_by_example.py
	new file:   inpaint/model/plms_sampler.py
	new file:   inpaint/model/power_paint/__init__.py
	new file:   inpaint/model/power_paint/pipeline_powerpaint.py
	new file:   inpaint/model/power_paint/power_paint.py
	new file:   inpaint/model/power_paint/power_paint_v2.py
	new file:   inpaint/model/power_paint/powerpaint_tokenizer.py
root 2024-08-20 21:17:33 +02:00
parent 6b3f4ae1c0
commit 70af4845af
232 changed files with 54070 additions and 0 deletions

23
inpaint/__init__.py Normal file

@@ -0,0 +1,23 @@
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
os.environ["LRU_CACHE_CAPACITY"] = "1"
# prevent CPU memory leak when running the model on GPU
# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"

import warnings

warnings.simplefilter("ignore", UserWarning)


def entry_point():
    # Import here so that os.environ["XDG_CACHE_HOME"] = args.model_cache_dir
    # takes effect for diffusers before its constants module is loaded.
    # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
    from inpaint.cli import typer_app

    typer_app()
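Because these environment variables must be set before torch and diffusers are imported, the package is meant to be launched through this entry point; for example (a sketch, assuming the dependencies are installed):

python3 -m inpaint start --model=lama --device=cpu --port=8080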

4
inpaint/__main__.py Normal file

@@ -0,0 +1,4 @@
from inpaint import entry_point

if __name__ == "__main__":
    entry_point()

398
inpaint/api.py Normal file

@@ -0,0 +1,398 @@
import asyncio
import os
import threading
import time
import traceback
from pathlib import Path
from typing import Optional, Dict, List
import cv2
import numpy as np
import socketio
import torch
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass
import uvicorn
from PIL import Image
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, Response
from fastapi.staticfiles import StaticFiles
from loguru import logger
from socketio import AsyncServer
from inpaint.file_manager import FileManager
from inpaint.helper import (
load_img,
decode_base64_to_image,
pil_to_bytes,
numpy_to_bytes,
concat_alpha_channel,
gen_frontend_mask,
adjust_mask,
)
from inpaint.model.utils import torch_gc
from inpaint.model_manager import ModelManager
from inpaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
from inpaint.plugins.base_plugin import BasePlugin
from inpaint.plugins.remove_bg import RemoveBG
from inpaint.schema import (
GenInfoResponse,
ApiConfig,
ServerConfigResponse,
SwitchModelRequest,
InpaintRequest,
RunPluginRequest,
SDSampler,
PluginInfo,
AdjustMaskRequest,
RemoveBGModel,
SwitchPluginModelRequest,
ModelInfo,
InteractiveSegModel,
RealESRGANModel,
)
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
WEB_APP_DIR = CURRENT_DIR / "web_app"
def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
rich_available = True
except Exception:
pass
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get("detail", ""),
"body": vars(e).get("body", ""),
"errors": str(e),
}
if not isinstance(
e, HTTPException
): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(
show_locals=True,
max_frames=2,
extra_lines=1,
suppress=[anyio, starlette],
word_wrap=False,
width=min([console.width, 200]),
)
else:
traceback.print_exc()
return JSONResponse(
status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
)
@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)
@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_origins": ["*"],
"allow_credentials": True,
"expose_headers": ["X-Seed"],
}
app.add_middleware(CORSMiddleware, **cors_options)
global_sio: AsyncServer = None
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
    # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
    # logger.info(f"diffusion callback: step={step}, timestep={timestep}")
    # We use an asyncio loop for task processing. Perhaps in the future we can add a
    # processing queue similar to InvokeAI, but for now we just start a separate event
    # loop. It should not make a difference for single-user use.
    asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
    return {}
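diffuser_callback matches the callback_on_step_end hook of recent diffusers pipelines: it receives the pipeline, step index, timestep, and callback kwargs, and returns a (here empty) kwargs dict. A minimal sketch of attaching it, where pipe, init_image, and mask_image are placeholders rather than names from this commit:

# Sketch only: each denoising step emits a "diffusion_progress" socket.io event.
result = pipe(
    prompt="a fox is sitting on a bench",
    image=init_image,
    mask_image=mask_image,
    callback_on_step_end=diffuser_callback,
).images[0]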
class Api:
def __init__(self, app: FastAPI, config: ApiConfig):
self.app = app
self.config = config
self.router = APIRouter()
self.queue_lock = threading.Lock()
api_middleware(self.app)
self.file_manager = self._build_file_manager()
self.plugins = self._build_plugins()
self.model_manager = self._build_model_manager()
# fmt: off
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
response_model=ServerConfigResponse)
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
global global_sio
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
self.app.mount("/ws", self.combined_asgi_app)
global_sio = self.sio
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
def api_save_image(self, file: UploadFile):
filename = file.filename
origin_image_bytes = file.file.read()
with open(self.config.output_dir / filename, "wb") as fw:
fw.write(origin_image_bytes)
def api_current_model(self) -> ModelInfo:
return self.model_manager.current_model
def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
if req.name == self.model_manager.name:
return self.model_manager.current_model
self.model_manager.switch(req.name)
return self.model_manager.current_model
def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
if req.plugin_name in self.plugins:
self.plugins[req.plugin_name].switch_model(req.model_name)
if req.plugin_name == RemoveBG.name:
self.config.remove_bg_model = req.model_name
if req.plugin_name == RealESRGANUpscaler.name:
self.config.realesrgan_model = req.model_name
if req.plugin_name == InteractiveSeg.name:
self.config.interactive_seg_model = req.model_name
torch_gc()
def api_server_config(self) -> ServerConfigResponse:
plugins = []
for it in self.plugins.values():
plugins.append(
PluginInfo(
name=it.name,
support_gen_image=it.support_gen_image,
support_gen_mask=it.support_gen_mask,
)
)
return ServerConfigResponse(
plugins=plugins,
modelInfos=self.model_manager.scan_models(),
removeBGModel=self.config.remove_bg_model,
removeBGModels=RemoveBGModel.values(),
realesrganModel=self.config.realesrgan_model,
realesrganModels=RealESRGANModel.values(),
interactiveSegModel=self.config.interactive_seg_model,
interactiveSegModels=InteractiveSegModel.values(),
enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet,
controlnetMethod=self.model_manager.controlnet_method,
disableModelSwitch=False,
isDesktop=False,
samplers=self.api_samplers(),
)
def api_input_image(self) -> FileResponse:
if self.config.input and self.config.input.is_file():
return FileResponse(self.config.input)
raise HTTPException(status_code=404, detail="Input image not found")
def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
_, _, info = load_img(file.file.read(), return_info=True)
parts = info.get("parameters", "").split("Negative prompt: ")
prompt = parts[0].strip()
negative_prompt = ""
if len(parts) > 1:
negative_prompt = parts[1].split("\n")[0].strip()
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
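The "parameters" value parsed here follows the common A1111-style PNG metadata layout. A hypothetical chunk such as:

a cat sitting on a bench
Negative prompt: blurry, low quality
Steps: 20, Sampler: DDIM, Seed: 42

yields prompt "a cat sitting on a bench" and negative_prompt "blurry, low quality".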
def api_inpaint(self, req: InpaintRequest):
image, alpha_channel, infos = decode_base64_to_image(req.image)
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) do not match.",
            )
if req.paint_by_example_example_image:
paint_by_example_image, _, _ = decode_base64_to_image(
req.paint_by_example_example_image
)
start = time.time()
rgb_np_img = self.model_manager(image, mask, req)
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
torch_gc()
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
ext = "png"
res_img_bytes = pil_to_bytes(
Image.fromarray(rgb_res),
ext=ext,
quality=self.config.quality,
infos=infos,
)
asyncio.run(self.sio.emit("diffusion_finish"))
return Response(
content=res_img_bytes,
media_type=f"image/{ext}",
headers={"X-Seed": str(req.sd_seed)},
)
def api_run_plugin_gen_image(self, req: RunPluginRequest):
ext = "png"
if req.name not in self.plugins:
raise HTTPException(status_code=422, detail="Plugin not found")
if not self.plugins[req.name].support_gen_image:
raise HTTPException(
status_code=422, detail="Plugin does not support output image"
)
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
torch_gc()
if bgr_or_rgba_np_img.shape[2] == 4:
rgba_np_img = bgr_or_rgba_np_img
else:
rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
return Response(
content=pil_to_bytes(
Image.fromarray(rgba_np_img),
ext=ext,
quality=self.config.quality,
infos=infos,
),
media_type=f"image/{ext}",
)
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
if req.name not in self.plugins:
raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output mask"
            )
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
torch_gc()
res_mask = gen_frontend_mask(bgr_or_gray_mask)
return Response(
content=numpy_to_bytes(res_mask, "png"),
media_type="image/png",
)
def api_samplers(self) -> List[str]:
return [member.value for member in SDSampler.__members__.values()]
def api_adjust_mask(self, req: AdjustMaskRequest):
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
mask = adjust_mask(mask, req.kernel_size, req.operate)
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
def launch(self):
self.app.include_router(self.router)
uvicorn.run(
self.combined_asgi_app,
host=self.config.host,
port=self.config.port,
timeout_keep_alive=999999999,
)
def _build_file_manager(self) -> Optional[FileManager]:
if self.config.input and self.config.input.is_dir():
logger.info(
f"Input is directory, initialize file manager {self.config.input}"
)
return FileManager(
app=self.app,
input_dir=self.config.input,
mask_dir=self.config.mask_dir,
output_dir=self.config.output_dir,
)
return None
def _build_plugins(self) -> Dict[str, BasePlugin]:
return build_plugins(
self.config.enable_interactive_seg,
self.config.interactive_seg_model,
self.config.interactive_seg_device,
self.config.enable_remove_bg,
self.config.remove_bg_model,
self.config.enable_anime_seg,
self.config.enable_realesrgan,
self.config.realesrgan_device,
self.config.realesrgan_model,
self.config.enable_gfpgan,
self.config.gfpgan_device,
self.config.enable_restoreformer,
self.config.restoreformer_device,
self.config.no_half,
)
def _build_model_manager(self):
return ModelManager(
name=self.config.model,
device=torch.device(self.config.device),
no_half=self.config.no_half,
low_mem=self.config.low_mem,
disable_nsfw=self.config.disable_nsfw_checker,
sd_cpu_textencoder=self.config.cpu_textencoder,
local_files_only=self.config.local_files_only,
cpu_offload=self.config.cpu_offload,
callback=diffuser_callback,
)

128
inpaint/batch_processing.py Normal file

@@ -0,0 +1,128 @@
import json
from pathlib import Path
from typing import Dict, Optional
import cv2
import numpy as np
from PIL import Image
from loguru import logger
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TimeElapsedColumn,
MofNCompleteColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
)
from inpaint.helper import pil_to_bytes
from inpaint.model.utils import torch_gc
from inpaint.model_manager import ModelManager
from inpaint.schema import InpaintRequest
def glob_images(path: Path) -> Dict[str, Path]:
    # png/jpg/jpeg
    if path.is_file():
        return {path.stem: path}
    elif path.is_dir():
        res = {}
        for it in path.glob("*.*"):
            if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
                res[it.stem] = it
        return res
    # return an empty dict instead of None so callers can call len() safely
    return {}
def batch_inpaint(
model: str,
device,
image: Path,
mask: Path,
output: Path,
config: Optional[Path] = None,
concat: bool = False,
):
if image.is_dir() and output.is_file():
logger.error(
"invalid --output: when image is a directory, output should be a directory"
)
exit(-1)
output.mkdir(parents=True, exist_ok=True)
image_paths = glob_images(image)
mask_paths = glob_images(mask)
if len(image_paths) == 0:
logger.error("invalid --image: empty image folder")
exit(-1)
if len(mask_paths) == 0:
logger.error("invalid --mask: empty mask folder")
exit(-1)
if config is None:
inpaint_request = InpaintRequest()
logger.info(f"Using default config: {inpaint_request}")
else:
with open(config, "r", encoding="utf-8") as f:
inpaint_request = InpaintRequest(**json.load(f))
logger.info(f"Using config: {inpaint_request}")
model_manager = ModelManager(name=model, device=device)
first_mask = list(mask_paths.values())[0]
console = Console()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
console=console,
transient=False,
) as progress:
task = progress.add_task("Batch processing...", total=len(image_paths))
for stem, image_p in image_paths.items():
if stem not in mask_paths and mask.is_dir():
progress.log(f"mask for {image_p} not found")
progress.update(task, advance=1)
continue
mask_p = mask_paths.get(stem, first_mask)
infos = Image.open(image_p).info
img = np.array(Image.open(image_p).convert("RGB"))
mask_img = np.array(Image.open(mask_p).convert("L"))
if mask_img.shape[:2] != img.shape[:2]:
progress.log(
f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
)
mask_img = cv2.resize(
mask_img,
(img.shape[1], img.shape[0]),
interpolation=cv2.INTER_NEAREST,
)
mask_img[mask_img >= 127] = 255
mask_img[mask_img < 127] = 0
# bgr
inpaint_result = model_manager(img, mask_img, inpaint_request)
inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
if concat:
mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
save_p = output / f"{stem}.png"
with open(save_p, "wb") as fw:
fw.write(img_bytes)
progress.update(task, advance=1)
torch_gc()
            # pid = psutil.Process().pid
            # memory_info = psutil.Process(pid).memory_info()
            # memory_in_mb = memory_info.rss / (1024 * 1024)
            # print(f"image size: {img.shape}, current process memory usage: {memory_in_mb}MB")

109
inpaint/benchmark.py Normal file

@@ -0,0 +1,109 @@
#!/usr/bin/env python3
import argparse
import os
import time
import numpy as np
import nvidia_smi
import psutil
import torch
from inpaint.model_manager import ModelManager
from inpaint.schema import InpaintRequest, HDStrategy, SDSampler
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass
NUM_THREADS = str(4)
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
def run_model(model, size):
# RGB
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
mask = np.random.randint(0, 255, size).astype(np.uint8)
config = InpaintRequest(
ldm_steps=2,
hd_strategy=HDStrategy.ORIGINAL,
hd_strategy_crop_margin=128,
hd_strategy_crop_trigger_size=128,
hd_strategy_resize_limit=128,
prompt="a fox is sitting on a bench",
sd_steps=5,
sd_sampler=SDSampler.ddim,
)
model(image, mask, config)
def benchmark(model, times: int, empty_cache: bool):
sizes = [(512, 512)]
nvidia_smi.nvmlInit()
device_id = 0
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
def format(metrics):
return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
process = psutil.Process(os.getpid())
    # For each size, report latency plus GPU memory and RAM usage metrics
for size in sizes:
torch.cuda.empty_cache()
time_metrics = []
cpu_metrics = []
memory_metrics = []
gpu_memory_metrics = []
for _ in range(times):
start = time.time()
run_model(model, size)
torch.cuda.synchronize()
# cpu_metrics.append(process.cpu_percent())
time_metrics.append((time.time() - start) * 1000)
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
gpu_memory_metrics.append(
nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
)
print(f"size: {size}".center(80, "-"))
# print(f"cpu: {format(cpu_metrics)}")
print(f"latency: {format(time_metrics)}ms")
print(f"memory: {format(memory_metrics)} MB")
print(f"gpu memory: {format(gpu_memory_metrics)} MB")
nvidia_smi.nvmlShutdown()
def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--name")
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--times", default=10, type=int)
parser.add_argument("--empty-cache", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = get_args_parser()
device = torch.device(args.device)
model = ModelManager(
name=args.name,
device=device,
disable_nsfw=True,
sd_cpu_textencoder=True,
)
benchmark(model, args.times, args.empty_cache)
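A typical run (assuming a CUDA device and that the nvidia_smi and psutil packages are installed):

python3 -m inpaint.benchmark --name lama --device cuda --times 10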

232
inpaint/cli.py Normal file

@@ -0,0 +1,232 @@
import webbrowser
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
import typer
from fastapi import FastAPI
from loguru import logger
from typer import Option
from typer_config import use_json_config
from inpaint.const import *
from inpaint.runtime import setup_model_dir, dump_environment_info, check_device
from inpaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
@typer_app.command(help="Install all plugins dependencies")
def install_plugins_packages():
from inpaint.installer import install_plugins_package
install_plugins_package()
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
def download(
model: str = Option(
..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
),
model_dir: Path = Option(
DEFAULT_MODEL_DIR,
help=MODEL_DIR_HELP,
file_okay=False,
callback=setup_model_dir,
),
):
from inpaint.download import cli_download_model
cli_download_model(model)
@typer_app.command(name="list", help="List downloaded models")
def list_model(
model_dir: Path = Option(
DEFAULT_MODEL_DIR,
help=MODEL_DIR_HELP,
file_okay=False,
callback=setup_model_dir,
),
):
from inpaint.download import scan_models
scanned_models = scan_models()
for it in scanned_models:
print(it.name)
@typer_app.command(help="Batch processing images")
def run(
model: str = Option("lama"),
device: Device = Option(Device.cpu),
image: Path = Option(..., help="Image folders or file path"),
    mask: Path = Option(
        ...,
        help="Mask folders or file path. "
        "If it is a directory, the mask images in the directory should have the same names as the original images. "
        "If it is a file, all images will use this mask. "
        "Masks are automatically resized to the size of the original image.",
    ),
output: Path = Option(..., help="Output directory or file path"),
config: Path = Option(
None, help="Config file path. You can use dump command to create a base config."
),
concat: bool = Option(
False, help="Concat original image, mask and output images into one image"
),
model_dir: Path = Option(
DEFAULT_MODEL_DIR,
help=MODEL_DIR_HELP,
file_okay=False,
callback=setup_model_dir,
),
):
from inpaint.download import cli_download_model, scan_models
scanned_models = scan_models()
if model not in [it.name for it in scanned_models]:
        logger.info(f"{model} not found in {model_dir}, trying to download it")
cli_download_model(model)
from inpaint.batch_processing import batch_inpaint
batch_inpaint(model, device, image, mask, output, config, concat)
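For example (paths are placeholders), batch-inpainting a folder of photos with matching per-image masks:

python3 -m inpaint run --model=lama --device=cpu --image=./photos --mask=./masks --output=./results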
@typer_app.command(help="Start IOPaint server")
@use_json_config()
def start(
host: str = Option("127.0.0.1"),
port: int = Option(8080),
inbrowser: bool = Option(False, help=INBROWSER_HELP),
model: str = Option(
DEFAULT_MODEL,
help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n"
f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.",
),
model_dir: Path = Option(
DEFAULT_MODEL_DIR,
help=MODEL_DIR_HELP,
dir_okay=True,
file_okay=False,
callback=setup_model_dir,
),
low_mem: bool = Option(False, help=LOW_MEM_HELP),
no_half: bool = Option(False, help=NO_HALF_HELP),
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
device: Device = Option(Device.cpu),
input: Optional[Path] = Option(None, help=INPUT_HELP),
    mask_dir: Optional[Path] = Option(
        None, help=MASK_DIR_HELP, dir_okay=True, file_okay=False
    ),
output_dir: Optional[Path] = Option(
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
),
quality: int = Option(95, help=QUALITY_HELP),
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
interactive_seg_model: InteractiveSegModel = Option(
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
),
interactive_seg_device: Device = Option(Device.cpu),
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
enable_realesrgan: bool = Option(False),
realesrgan_device: Device = Option(Device.cpu),
realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
enable_gfpgan: bool = Option(False),
gfpgan_device: Device = Option(Device.cpu),
enable_restoreformer: bool = Option(False),
restoreformer_device: Device = Option(Device.cpu),
):
dump_environment_info()
device = check_device(device)
    if input and not input.exists():
        logger.error(f"invalid --input: {input} does not exist")
        exit(-1)
    if mask_dir and not mask_dir.exists():
        logger.error(f"invalid --mask-dir: {mask_dir} does not exist")
        exit(-1)
if input and input.is_dir() and not output_dir:
logger.error("invalid --output-dir: --output-dir must be set when --input is a directory")
exit(-1)
if output_dir:
output_dir = output_dir.expanduser().absolute()
logger.info(f"Image will be saved to {output_dir}")
if not output_dir.exists():
logger.info(f"Create output directory {output_dir}")
output_dir.mkdir(parents=True)
if mask_dir:
mask_dir = mask_dir.expanduser().absolute()
model_dir = model_dir.expanduser().absolute()
if local_files_only:
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
from inpaint.download import cli_download_model, scan_models
scanned_models = scan_models()
if model not in [it.name for it in scanned_models]:
        logger.info(f"{model} not found in {model_dir}, trying to download it")
cli_download_model(model)
from inpaint.api import Api
from inpaint.schema import ApiConfig
@asynccontextmanager
async def lifespan(app: FastAPI):
if inbrowser:
webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True)
yield
app = FastAPI(lifespan=lifespan)
api_config = ApiConfig(
host=host,
port=port,
inbrowser=inbrowser,
model=model,
no_half=no_half,
low_mem=low_mem,
cpu_offload=cpu_offload,
disable_nsfw_checker=disable_nsfw_checker,
local_files_only=local_files_only,
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
device=device,
input=input,
mask_dir=mask_dir,
output_dir=output_dir,
quality=quality,
enable_interactive_seg=enable_interactive_seg,
interactive_seg_model=interactive_seg_model,
interactive_seg_device=interactive_seg_device,
enable_remove_bg=enable_remove_bg,
remove_bg_model=remove_bg_model,
enable_anime_seg=enable_anime_seg,
enable_realesrgan=enable_realesrgan,
realesrgan_device=realesrgan_device,
realesrgan_model=realesrgan_model,
enable_gfpgan=enable_gfpgan,
gfpgan_device=gfpgan_device,
enable_restoreformer=enable_restoreformer,
restoreformer_device=restoreformer_device,
)
print(api_config.model_dump_json(indent=4))
api = Api(app, api_config)
api.launch()
@typer_app.command(help="Start IOPaint web config page")
def start_web_config(
config_file: Path = Option("config.json"),
):
dump_environment_info()
from inpaint.web_config import main
main(config_file)

128
inpaint/const.py Normal file

@@ -0,0 +1,128 @@
import os
from typing import List
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
ANYTEXT_NAME = "Sanster/AnyText"
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
MPS_UNSUPPORT_MODELS = [
"lama",
"ldm",
"zits",
"mat",
"fcf",
"cv2",
"manga",
]
DEFAULT_MODEL = "lama"
AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
DIFFUSION_MODELS = [
"runwayml/stable-diffusion-inpainting",
"Uminosachi/realisticVisionV51_v51VAE-inpainting",
"redstonehero/dreamshaper-inpainting",
"Sanster/anything-4.0-inpainting",
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
"Fantasy-Studio/Paint-by-Example",
POWERPAINT_NAME,
ANYTEXT_NAME,
]
NO_HALF_HELP = """
Use the full precision (fp32) model.
If the results generated by your diffusion model are always black or green, use this argument.
"""
CPU_OFFLOAD_HELP = """
Offloads diffusion model's weight to CPU RAM, significantly reducing vRAM usage.
"""
LOW_MEM_HELP = "Enable attention slicing and vae tiling to save memory."
DISABLE_NSFW_HELP = """
Disable NSFW checker for diffusion model.
"""
CPU_TEXTENCODER_HELP = """
Run diffusion models text encoder on CPU to reduce vRAM usage.
"""
SD_CONTROLNET_CHOICES: List[str] = [
"lllyasviel/control_v11p_sd15_canny",
# "lllyasviel/control_v11p_sd15_seg",
"lllyasviel/control_v11p_sd15_openpose",
"lllyasviel/control_v11p_sd15_inpaint",
"lllyasviel/control_v11f1p_sd15_depth",
]
SD_BRUSHNET_CHOICES: List[str] = [
"Sanster/brushnet_random_mask",
"Sanster/brushnet_segmentation_mask",
]
SD2_CONTROLNET_CHOICES = [
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-openpose-diffusers",
]
SDXL_CONTROLNET_CHOICES = [
"thibaud/controlnet-openpose-sdxl-1.0",
"destitech/controlnet-inpaint-dreamer-sdxl",
"diffusers/controlnet-canny-sdxl-1.0",
"diffusers/controlnet-canny-sdxl-1.0-mid",
"diffusers/controlnet-canny-sdxl-1.0-small",
"diffusers/controlnet-depth-sdxl-1.0",
"diffusers/controlnet-depth-sdxl-1.0-mid",
"diffusers/controlnet-depth-sdxl-1.0-small",
]
LOCAL_FILES_ONLY_HELP = """
When loading diffusion models, use local files only and do not connect to the HuggingFace server.
"""
DEFAULT_MODEL_DIR = os.path.abspath(
os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
)
MODEL_DIR_HELP = f"""
Model download directory (set via the XDG_CACHE_HOME environment variable); by default models are downloaded to {DEFAULT_MODEL_DIR}
"""
OUTPUT_DIR_HELP = """
Result images will be saved to output directory automatically.
"""
MASK_DIR_HELP = """
You can view masks in FileManager
"""
INPUT_HELP = """
If the input is an image file, it will be loaded by default.
If the input is a directory, you can browse and select images in the file manager.
"""
GUI_HELP = """
Launch IOPaint as a desktop app
"""
QUALITY_HELP = """
Quality of image encoding, 0-100. Default is 95; higher quality generates larger files.
"""
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
REMOVE_BG_HELP = "Enable remove background plugin. Always runs on CPU."
ANIMESEG_HELP = "Enable anime segmentation plugin. Always runs on CPU."
REALESRGAN_HELP = "Enable realesrgan super resolution"
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
INBROWSER_HELP = "Automatically launch IOPaint in a new tab on the default browser"

313
inpaint/download.py Normal file

@@ -0,0 +1,313 @@
import glob
import json
import os
from functools import lru_cache
from typing import List, Optional
from inpaint.schema import ModelType, ModelInfo
from loguru import logger
from pathlib import Path
from inpaint.const import (
DEFAULT_MODEL_DIR,
DIFFUSERS_SD_CLASS_NAME,
DIFFUSERS_SD_INPAINT_CLASS_NAME,
DIFFUSERS_SDXL_CLASS_NAME,
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
ANYTEXT_NAME,
)
from inpaint.model.original_sd_configs import get_config_files
def cli_download_model(model: str):
from inpaint.model import models
from inpaint.model.utils import handle_from_pretrained_exceptions
if model in models and models[model].is_erase_model:
logger.info(f"Downloading {model}...")
models[model].download()
logger.info("Done.")
elif model == ANYTEXT_NAME:
logger.info(f"Downloading {model}...")
models[model].download()
logger.info("Done.")
else:
logger.info(f"Downloading model from Huggingface: {model}")
from diffusers import DiffusionPipeline
downloaded_path = handle_from_pretrained_exceptions(
DiffusionPipeline.download,
pretrained_model_name=model,
variant="fp16",
resume_download=True,
)
logger.info(f"Done. Downloaded to {downloaded_path}")
def folder_name_to_show_name(name: str) -> str:
return name.replace("models--", "").replace("--", "/")
@lru_cache(maxsize=512)
def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]:
if "inpaint" in Path(model_abs_path).name.lower():
model_type = ModelType.DIFFUSERS_SD_INPAINT
else:
# load once to check num_in_channels
from diffusers import StableDiffusionInpaintPipeline
try:
StableDiffusionInpaintPipeline.from_single_file(
model_abs_path,
load_safety_checker=False,
num_in_channels=9,
original_config_file=get_config_files()['v1']
)
model_type = ModelType.DIFFUSERS_SD_INPAINT
except ValueError as e:
if "[320, 4, 3, 3]" in str(e):
model_type = ModelType.DIFFUSERS_SD
else:
                logger.info(f"Ignore non-SD file: {model_abs_path}")
return
except Exception as e:
logger.error(f"Failed to load {model_abs_path}: {e}")
return
return model_type
@lru_cache()
def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]:
    if "inpaint" in Path(model_abs_path).name.lower():
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
# load once to check num_in_channels
from diffusers import StableDiffusionXLInpaintPipeline
try:
model = StableDiffusionXLInpaintPipeline.from_single_file(
model_abs_path,
load_safety_checker=False,
num_in_channels=9,
original_config_file=get_config_files()['xl'],
)
if model.unet.config.in_channels == 9:
# https://github.com/huggingface/diffusers/issues/6610
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
model_type = ModelType.DIFFUSERS_SDXL
except ValueError as e:
if "[320, 4, 3, 3]" in str(e):
model_type = ModelType.DIFFUSERS_SDXL
else:
logger.info(f"Ignore non sdxl file: {model_abs_path}")
return
except Exception as e:
logger.error(f"Failed to load {model_abs_path}: {e}")
return
return model_type
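Both probes exploit the UNet conv_in shape: inpainting UNets take 9 input channels (4 noisy-latent + 4 masked-image-latent + 1 mask), while text-to-image UNets take 4, so forcing num_in_channels=9 either loads cleanly (inpaint model) or fails with a size mismatch mentioning [320, 4, 3, 3] (normal model). The same check for an already-loaded pipeline, as a standalone sketch (pipe is a placeholder):

def is_inpaint_unet(pipe) -> bool:
    # 9 = 4 latent + 4 masked-image-latent + 1 mask channel
    return pipe.unet.config.in_channels == 9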
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
stable_diffusion_dir = cache_dir / "stable_diffusion"
cache_file = stable_diffusion_dir / "iopaint_cache.json"
model_type_cache = {}
    if cache_file.exists():
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                model_type_cache = json.load(f)
                assert isinstance(model_type_cache, dict)
        except Exception:
            pass
res = []
for it in stable_diffusion_dir.glob("*.*"):
if it.suffix not in [".safetensors", ".ckpt"]:
continue
model_abs_path = str(it.absolute())
model_type = model_type_cache.get(it.name)
if model_type is None:
model_type = get_sd_model_type(model_abs_path)
if model_type is None:
continue
model_type_cache[it.name] = model_type
res.append(
ModelInfo(
name=it.name,
path=model_abs_path,
model_type=model_type,
is_single_file_diffusers=True,
)
)
if stable_diffusion_dir.exists():
with open(cache_file, "w", encoding="utf-8") as fw:
json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
sdxl_model_type_cache = {}
    if sdxl_cache_file.exists():
        try:
            with open(sdxl_cache_file, "r", encoding="utf-8") as f:
                sdxl_model_type_cache = json.load(f)
                assert isinstance(sdxl_model_type_cache, dict)
        except Exception:
            pass
    for it in stable_diffusion_xl_dir.glob("*.*"):
        if it.suffix not in [".safetensors", ".ckpt"]:
            continue
        model_abs_path = str(it.absolute())
        model_type = sdxl_model_type_cache.get(it.name)
        if model_type is None:
            model_type = get_sdxl_model_type(model_abs_path)
        if model_type is None:
            continue
        sdxl_model_type_cache[it.name] = model_type
        res.append(
            ModelInfo(
                name=it.name,
                path=model_abs_path,
                model_type=model_type,
                is_single_file_diffusers=True,
            )
        )
    if stable_diffusion_xl_dir.exists():
        with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
            json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
    return res
def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
res = []
from inpaint.model import models
# logger.info(f"Scanning inpaint models in {model_dir}")
for name, m in models.items():
if m.is_erase_model and m.is_downloaded():
res.append(
ModelInfo(
name=name,
path=name,
model_type=ModelType.INPAINT,
)
)
return res
def scan_diffusers_models() -> List[ModelInfo]:
from huggingface_hub.constants import HF_HUB_CACHE
available_models = []
cache_dir = Path(HF_HUB_CACHE)
# logger.info(f"Scanning diffusers models in {cache_dir}")
diffusers_model_names = []
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
for it in model_index_files:
it = Path(it)
        with open(it, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except Exception:
                continue
_class_name = data["_class_name"]
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name in diffusers_model_names:
continue
if "PowerPaint" in name:
model_type = ModelType.DIFFUSERS_OTHER
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
elif _class_name in [
"StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline",
"KandinskyV22InpaintPipeline",
"AnyText",
]:
model_type = ModelType.DIFFUSERS_OTHER
else:
continue
diffusers_model_names.append(name)
available_models.append(
ModelInfo(
name=name,
path=name,
model_type=model_type,
)
)
return available_models
def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
available_models = []
diffusers_model_names = []
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
for it in model_index_files:
it = Path(it)
        with open(it, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except Exception:
                logger.error(
                    f"Failed to load {it}, please try reverting to the original model or fixing model_index.json by hand."
                )
                continue
_class_name = data["_class_name"]
name = folder_name_to_show_name(it.parent.name)
if name in diffusers_model_names:
continue
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
continue
diffusers_model_names.append(name)
available_models.append(
ModelInfo(
name=name,
path=str(it.parent.absolute()),
model_type=model_type,
)
)
return available_models
def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
available_models = []
stable_diffusion_dir = cache_dir / "stable_diffusion"
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir))
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir))
return available_models
def scan_models() -> List[ModelInfo]:
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
available_models = []
available_models.extend(scan_inpaint_models(model_dir))
available_models.extend(scan_single_file_diffusion_models(model_dir))
available_models.extend(scan_diffusers_models())
available_models.extend(scan_converted_diffusers_models(model_dir))
return available_models
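Taken together, scan_models expects roughly the following layout under the model directory (a sketch; the checkpoint names are placeholders):

<XDG_CACHE_HOME or ~/.cache>/
├── stable_diffusion/                 single-file SD 1.x/2.x checkpoints
│   ├── some-model-inpainting.safetensors
│   └── iopaint_cache.json            cached model-type detection results
├── stable_diffusion_xl/              single-file SDXL checkpoints
│   └── some-sdxl-model.ckpt
└── huggingface/hub/                  diffusers-format models (HF_HUB_CACHE)
    └── models--runwayml--stable-diffusion-inpainting/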

1
inpaint/file_manager/__init__.py Normal file

@@ -0,0 +1 @@
from .file_manager import FileManager

218
inpaint/file_manager/file_manager.py Normal file

@@ -0,0 +1,218 @@
import os
from io import BytesIO
from pathlib import Path
from typing import List
from PIL import Image, ImageOps, PngImagePlugin
from fastapi import FastAPI, HTTPException
from loguru import logger
from starlette.responses import FileResponse
from ..schema import MediasResponse, MediaTab
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img
class FileManager:
def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
self.app = app
self.input_dir: Path = input_dir
self.mask_dir: Path = mask_dir
self.output_dir: Path = output_dir
self.image_dir_filenames = []
self.output_dir_filenames = []
if not self.thumbnail_directory.exists():
self.thumbnail_directory.mkdir(parents=True)
# fmt: off
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
# fmt: on
def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
img_dir = self._get_dir(tab)
return self._media_names(img_dir)
def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
file_path = self._get_file(tab, filename)
return FileResponse(file_path, media_type="image/png")
    # /api/v1/media_thumbnail_file?tab=${tab}&filename=${filename.name}&width=${width}&height=${height}
def api_media_thumbnail_file(
self, tab: MediaTab, filename: str, width: int, height: int
) -> FileResponse:
img_dir = self._get_dir(tab)
thumb_filename, (width, height) = self.get_thumbnail(
img_dir, filename, width=width, height=height
)
thumbnail_filepath = self.thumbnail_directory / thumb_filename
return FileResponse(
thumbnail_filepath,
headers={
"X-Width": str(width),
"X-Height": str(height),
},
media_type="image/jpeg",
)
def _get_dir(self, tab: MediaTab) -> Path:
if tab == "input":
return self.input_dir
elif tab == "output":
return self.output_dir
elif tab == "mask":
return self.mask_dir
else:
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
def _get_file(self, tab: MediaTab, filename: str) -> Path:
file_path = self._get_dir(tab) / filename
if not file_path.exists():
raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
return file_path
@property
def thumbnail_directory(self) -> Path:
return self.output_dir / "thumbnails"
@staticmethod
def _media_names(directory: Path) -> List[MediasResponse]:
names = sorted([it.name for it in glob_img(directory)])
res = []
for name in names:
path = os.path.join(directory, name)
img = Image.open(path)
res.append(
MediasResponse(
name=name,
height=img.height,
width=img.width,
ctime=os.path.getctime(path),
mtime=os.path.getmtime(path),
)
)
return res
def get_thumbnail(
self, directory: Path, original_filename: str, width, height, **options
):
directory = Path(directory)
storage = FilesystemStorageBackend(self.app)
crop = options.get("crop", "fit")
background = options.get("background")
quality = options.get("quality", 90)
original_path, original_filename = os.path.split(original_filename)
original_filepath = os.path.join(directory, original_path, original_filename)
image = Image.open(BytesIO(storage.read(original_filepath)))
# keep ratio resize
if not width and not height:
width = 256
if width != 0:
height = int(image.height * width / image.width)
else:
width = int(image.width * height / image.height)
thumbnail_size = (width, height)
thumbnail_filename = generate_filename(
directory,
original_filename,
aspect_to_string(thumbnail_size),
crop,
background,
quality,
)
thumbnail_filepath = os.path.join(
self.thumbnail_directory, original_path, thumbnail_filename
)
if storage.exists(thumbnail_filepath):
return thumbnail_filepath, (width, height)
        try:
            image.load()
        except (IOError, OSError):
            logger.warning(f"Thumbnail could not load image: {original_filepath}")
            return thumbnail_filepath, (width, height)
# get original image format
options["format"] = options.get("format", image.format)
image = self._create_thumbnail(
image, thumbnail_size, crop, background=background
)
raw_data = self.get_raw_data(image, **options)
storage.save(thumbnail_filepath, raw_data)
return thumbnail_filepath, (width, height)
def get_raw_data(self, image, **options):
data = {
"format": self._get_format(image, **options),
"quality": options.get("quality", 90),
}
_file = BytesIO()
image.save(_file, **data)
return _file.getvalue()
@staticmethod
def colormode(image, colormode="RGB"):
if colormode == "RGB" or colormode == "RGBA":
if image.mode == "RGBA":
return image
if image.mode == "LA":
return image.convert("RGBA")
return image.convert(colormode)
if colormode == "GRAY":
return image.convert("L")
return image.convert(colormode)
    @staticmethod
    def background(original_image, color=0xFF):
        size = (max(original_image.size),) * 2
        image = Image.new("L", size, color)
        # use integer division: PIL paste coordinates must be ints
        image.paste(
            original_image,
            tuple(map(lambda x: (x[0] - x[1]) // 2, zip(size, original_image.size))),
        )
        return image
def _get_format(self, image, **options):
if options.get("format"):
return options.get("format")
if image.format:
return image.format
return "JPEG"
def _create_thumbnail(self, image, size, crop="fit", background=None):
try:
resample = Image.Resampling.LANCZOS
except AttributeError: # pylint: disable=raise-missing-from
resample = Image.ANTIALIAS
if crop == "fit":
image = ImageOps.fit(image, size, resample)
else:
image = image.copy()
image.thumbnail(size, resample=resample)
if background is not None:
image = self.background(image)
image = self.colormode(image)
return image

46
inpaint/file_manager/storage_backends.py Normal file

@@ -0,0 +1,46 @@
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
import errno
import os
from abc import ABC, abstractmethod
class BaseStorageBackend(ABC):
def __init__(self, app=None):
self.app = app
@abstractmethod
def read(self, filepath, mode="rb", **kwargs):
raise NotImplementedError
@abstractmethod
def exists(self, filepath):
raise NotImplementedError
@abstractmethod
def save(self, filepath, data):
raise NotImplementedError
class FilesystemStorageBackend(BaseStorageBackend):
def read(self, filepath, mode="rb", **kwargs):
with open(filepath, mode) as f: # pylint: disable=unspecified-encoding
return f.read()
def exists(self, filepath):
return os.path.exists(filepath)
def save(self, filepath, data):
directory = os.path.dirname(filepath)
if not os.path.exists(directory):
try:
os.makedirs(directory)
except OSError as e:
if e.errno != errno.EEXIST:
raise
if not os.path.isdir(directory):
raise IOError("{} is not a directory".format(directory))
with open(filepath, "wb") as f:
f.write(data)

65
inpaint/file_manager/utils.py Normal file

@@ -0,0 +1,65 @@
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
import hashlib
from pathlib import Path
from typing import Union
def generate_filename(directory: Path, original_filename, *options) -> str:
text = str(directory.absolute()) + original_filename
for v in options:
text += "%s" % v
md5_hash = hashlib.md5()
md5_hash.update(text.encode("utf-8"))
return md5_hash.hexdigest() + ".jpg"
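Usage is deterministic: the same directory, filename, and options always hash to the same thumbnail name. A hypothetical call:

name = generate_filename(Path("/data/output"), "cat.png", "256x192", "fit", None, 90)
# name is the md5 hex digest of the concatenated text, plus a ".jpg" suffix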
def parse_size(size):
if isinstance(size, int):
# If the size parameter is a single number, assume square aspect.
return [size, size]
if isinstance(size, (tuple, list)):
        if len(size) == 1:
            # If a single-value tuple/list is provided, expand it to two elements
            return size + type(size)(size)
return size
try:
thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
except ValueError:
raise ValueError( # pylint: disable=raise-missing-from
"Bad thumbnail size format. Valid format is INTxINT."
)
if len(thumbnail_size) == 1:
# If the size parameter only contains a single integer, assume square aspect.
thumbnail_size.append(thumbnail_size[0])
return thumbnail_size
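A few illustrative cases of the accepted inputs:

parse_size(128)        # -> [128, 128]   int: assume square
parse_size((200,))     # -> (200, 200)   single-element tuple expanded
parse_size("200x100")  # -> [200, 100]   "INTxINT" string
parse_size("64")       # -> [64, 64]     single-integer string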
def aspect_to_string(size):
if isinstance(size, str):
return size
return "x".join(map(str, size))
IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
def glob_img(p: Union[Path, str], recursive: bool = False):
p = Path(p)
if p.is_file() and p.suffix in IMG_SUFFIX:
yield p
else:
if recursive:
files = Path(p).glob("**/*.*")
else:
files = Path(p).glob("*.*")
for it in files:
if it.suffix not in IMG_SUFFIX:
continue
yield it

408
inpaint/helper.py Normal file

@@ -0,0 +1,408 @@
import base64
import imghdr
import io
import os
import sys
from typing import List, Optional, Dict, Tuple
from urllib.parse import urlparse
import cv2
from PIL import Image, ImageOps, PngImagePlugin
import numpy as np
import torch
from inpaint.const import MPS_UNSUPPORT_MODELS
from loguru import logger
from torch.hub import download_url_to_file, get_dir
import hashlib
def md5sum(filename):
md5 = hashlib.md5()
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
md5.update(chunk)
return md5.hexdigest()
def switch_mps_device(model_name, device):
    if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
        logger.info(f"{model_name} does not support mps, switching to cpu")
        return torch.device("cpu")
    return device
def get_cache_path_by_url(url):
parts = urlparse(url)
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
if not os.path.isdir(model_dir):
os.makedirs(model_dir)
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
return cached_file
def download_model(url, model_md5: str = None):
if os.path.exists(url):
cached_file = url
else:
cached_file = get_cache_path_by_url(url)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                logger.info(f"Download model success, md5: {_md5}")
            else:
                try:
                    os.remove(cached_file)
                    logger.error(
                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint. "
                        f"If you still have errors, please try to download the model manually first: https://lama-cleaner-docs.vercel.app/install/download_model_manually\n"
                    )
                except Exception:
                    logger.error(
                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart iopaint."
                    )
                exit(-1)
    return cached_file
def ceil_modulo(x, mod):
if x % mod == 0:
return x
return (x // mod + 1) * mod
def handle_error(model_path, model_md5, e):
    _md5 = md5sum(model_path)
    if _md5 != model_md5:
        try:
            os.remove(model_path)
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint. "
                f"If you still have errors, please try to download the model manually first: https://lama-cleaner-docs.vercel.app/install/download_model_manually\n"
            )
        except Exception:
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart iopaint."
            )
    else:
        logger.error(
            f"Failed to load model {model_path}, "
            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
    exit(-1)
def load_jit_model(url_or_path, device, model_md5: str):
if os.path.exists(url_or_path):
model_path = url_or_path
else:
model_path = download_model(url_or_path, model_md5)
logger.info(f"Loading model from: {model_path}")
try:
model = torch.jit.load(model_path, map_location="cpu").to(device)
except Exception as e:
handle_error(model_path, model_md5, e)
model.eval()
return model
def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
if os.path.exists(url_or_path):
model_path = url_or_path
else:
model_path = download_model(url_or_path, model_md5)
try:
logger.info(f"Loading model from: {model_path}")
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.to(device)
except Exception as e:
handle_error(model_path, model_md5, e)
model.eval()
return model
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
data = cv2.imencode(
f".{ext}",
image_numpy,
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
)[1]
image_bytes = data.tobytes()
return image_bytes
def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
with io.BytesIO() as output:
kwargs = {k: v for k, v in infos.items() if v is not None}
if ext == "jpg":
ext = "jpeg"
if "png" == ext.lower() and "parameters" in kwargs:
pnginfo_data = PngImagePlugin.PngInfo()
pnginfo_data.add_text("parameters", kwargs["parameters"])
kwargs["pnginfo"] = pnginfo_data
pil_img.save(output, format=ext, quality=quality, **kwargs)
image_bytes = output.getvalue()
return image_bytes
def load_img(img_bytes, gray: bool = False, return_info: bool = False):
alpha_channel = None
image = Image.open(io.BytesIO(img_bytes))
if return_info:
infos = image.info
try:
image = ImageOps.exif_transpose(image)
    except Exception:
pass
if gray:
image = image.convert("L")
np_img = np.array(image)
else:
if image.mode == "RGBA":
np_img = np.array(image)
alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
else:
image = image.convert("RGB")
np_img = np.array(image)
if return_info:
return np_img, alpha_channel, infos
return np_img, alpha_channel
def norm_img(np_img):
if len(np_img.shape) == 2:
np_img = np_img[:, :, np.newaxis]
np_img = np.transpose(np_img, (2, 0, 1))
np_img = np_img.astype("float32") / 255
return np_img
def resize_max_size(
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    # Resize the image's longer side to size_limit if it exceeds size_limit
h, w = np_img.shape[:2]
if max(h, w) > size_limit:
ratio = size_limit / max(h, w)
new_w = int(w * ratio + 0.5)
new_h = int(h * ratio + 0.5)
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
else:
return np_img
def pad_img_to_modulo(
img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
"""
Args:
img: [H, W, C]
mod:
square: 是否为正方形
min_size:
Returns:
"""
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
height, width = img.shape[:2]
out_height = ceil_modulo(height, mod)
out_width = ceil_modulo(width, mod)
if min_size is not None:
assert min_size % mod == 0
out_width = max(min_size, out_width)
out_height = max(min_size, out_height)
if square:
max_size = max(out_height, out_width)
out_height = max_size
out_width = max_size
return np.pad(
img,
((0, out_height - height), (0, out_width - width), (0, 0)),
mode="symmetric",
)
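
For example, a 500x700 image padded with mod=64 comes out 512x704, since ceil_modulo rounds each side up to the next multiple of 64 (a minimal check using the helpers above):

import numpy as np

img = np.zeros((500, 700, 3), dtype=np.uint8)
padded = pad_img_to_modulo(img, mod=64)
print(padded.shape)  # (512, 704, 3); the extra rows/cols are mirrored, not zero-filled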
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
"""
Args:
mask: (h, w, 1) 0~255
Returns:
"""
height, width = mask.shape[:2]
_, thresh = cv2.threshold(mask, 127, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = []
for cnt in contours:
x, y, w, h = cv2.boundingRect(cnt)
box = np.array([x, y, x + w, y + h]).astype(int)
box[::2] = np.clip(box[::2], 0, width)
box[1::2] = np.clip(box[1::2], 0, height)
boxes.append(box)
return boxes
def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
    """
    Args:
        mask: (h, w) 0~255
    Returns:
        mask with only the largest-area contour kept (filled), or the input mask if no contour is found
    """
_, thresh = cv2.threshold(mask, 127, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
max_area = 0
max_index = -1
for i, cnt in enumerate(contours):
area = cv2.contourArea(cnt)
if area > max_area:
max_area = area
max_index = i
if max_index != -1:
new_mask = np.zeros_like(mask)
return cv2.drawContours(new_mask, contours, max_index, 255, -1)
else:
return mask
def is_mac():
return sys.platform == "darwin"
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
def decode_base64_to_image(
    encoding: str, gray=False
) -> Tuple[np.ndarray, Optional[np.ndarray], Dict]:
if encoding.startswith("data:image/") or encoding.startswith(
"data:application/octet-stream;base64,"
):
encoding = encoding.split(";")[1].split(",")[1]
image = Image.open(io.BytesIO(base64.b64decode(encoding)))
alpha_channel = None
try:
image = ImageOps.exif_transpose(image)
    except Exception:
pass
    # exif_transpose removes the EXIF rotation info, so image.info must be read after calling exif_transpose
infos = image.info
if gray:
image = image.convert("L")
np_img = np.array(image)
else:
if image.mode == "RGBA":
np_img = np.array(image)
alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
else:
image = image.convert("RGB")
np_img = np.array(image)
return np_img, alpha_channel, infos
def encode_pil_to_base64(image: Image.Image, quality: int, infos: Dict) -> bytes:
img_bytes = pil_to_bytes(
image,
"png",
quality=quality,
infos=infos,
)
return base64.b64encode(img_bytes)
def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
if alpha_channel is not None:
if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
)
rgb_np_img = np.concatenate(
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
return rgb_np_img
def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
    # frontend brush color "ffcc00bb"
# kernel_size = kernel_size*2+1
mask[mask >= 127] = 255
mask[mask < 127] = 0
if operate == "reverse":
mask = 255 - mask
else:
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
)
if operate == "expand":
mask = cv2.dilate(
mask,
kernel,
iterations=1,
)
else:
mask = cv2.erode(
mask,
kernel,
iterations=1,
)
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
return res_mask
def gen_frontend_mask(bgr_or_gray_mask):
if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
    # frontend brush color "ffcc00bb"
# TODO: how to set kernel size?
kernel_size = 9
bgr_or_gray_mask = cv2.dilate(
bgr_or_gray_mask,
np.ones((kernel_size, kernel_size), np.uint8),
iterations=1,
)
res_mask = np.zeros(
(bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
)
res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
return res_mask
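
Taken together, a typical preprocessing path through this module looks like the sketch below (illustrative only; "photo.jpg" and the size limits are placeholders):

with open("photo.jpg", "rb") as f:
    img_bytes = f.read()

rgb, alpha = load_img(img_bytes)             # HWC uint8 RGB plus optional alpha channel
rgb = resize_max_size(rgb, size_limit=2048)  # cap the longer side
padded = pad_img_to_modulo(rgb, mod=8)       # many models require mod-8 input sizes
chw = norm_img(padded)                       # CHW float32 in [0, 1], ready for a model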

inpaint/installer.py
@@ -0,0 +1,10 @@
import subprocess
import sys
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
def install_plugins_package():
install("rembg")

inpaint/model/__init__.py
@@ -0,0 +1,37 @@
from .anytext.anytext_model import AnyText
from .controlnet import ControlNet
from .fcf import FcF
from .instruct_pix2pix import InstructPix2Pix
from .kandinsky import Kandinsky22
from .lama import LaMa
from .ldm import LDM
from .manga import Manga
from .mat import MAT
from .mi_gan import MIGAN
from .opencv2 import OpenCV2
from .paint_by_example import PaintByExample
from .power_paint.power_paint import PowerPaint
from .sd import SD15, SD2, Anything4, RealisticVision14, SD
from .sdxl import SDXL
from .zits import ZITS
models = {
LaMa.name: LaMa,
LDM.name: LDM,
ZITS.name: ZITS,
MAT.name: MAT,
FcF.name: FcF,
OpenCV2.name: OpenCV2,
Manga.name: Manga,
MIGAN.name: MIGAN,
SD15.name: SD15,
Anything4.name: Anything4,
RealisticVision14.name: RealisticVision14,
SD2.name: SD2,
PaintByExample.name: PaintByExample,
InstructPix2Pix.name: InstructPix2Pix,
Kandinsky22.name: Kandinsky22,
SDXL.name: SDXL,
PowerPaint.name: PowerPaint,
AnyText.name: AnyText,
}
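
The registry makes model construction name-driven; a minimal sketch (assuming, as elsewhere in this codebase, that each class's name attribute is its registry key and the constructor takes a device):

import torch

name = "lama"                  # assumed to equal LaMa.name
model_cls = models[name]
model = model_cls(device=torch.device("cpu"))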

inpaint/model/anytext/__init__.py

inpaint/model/anytext/anytext_model.py
@@ -0,0 +1,73 @@
import torch
from huggingface_hub import hf_hub_download
from inpaint.const import ANYTEXT_NAME
from inpaint.model.anytext.anytext_pipeline import AnyTextPipeline
from inpaint.model.base import DiffusionInpaintModel
from inpaint.model.utils import get_torch_dtype, is_local_files_only
from inpaint.schema import InpaintRequest
class AnyText(DiffusionInpaintModel):
name = ANYTEXT_NAME
pad_mod = 64
is_erase_model = False
@staticmethod
def download(local_files_only=False):
hf_hub_download(
repo_id=ANYTEXT_NAME,
filename="model_index.json",
local_files_only=local_files_only,
)
ckpt_path = hf_hub_download(
repo_id=ANYTEXT_NAME,
filename="pytorch_model.fp16.safetensors",
local_files_only=local_files_only,
)
font_path = hf_hub_download(
repo_id=ANYTEXT_NAME,
filename="SourceHanSansSC-Medium.otf",
local_files_only=local_files_only,
)
return ckpt_path, font_path
def init_model(self, device, **kwargs):
local_files_only = is_local_files_only(**kwargs)
ckpt_path, font_path = self.download(local_files_only)
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
self.model = AnyTextPipeline(
ckpt_path=ckpt_path,
font_path=font_path,
device=device,
use_fp16=torch_dtype == torch.float16,
)
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to inpainting
return: BGR IMAGE
"""
height, width = image.shape[:2]
mask = mask.astype("float32") / 255.0
masked_image = image * (1 - mask)
# list of rgb ndarray
results, rtn_code, rtn_warning = self.model(
image=image,
masked_image=masked_image,
prompt=config.prompt,
negative_prompt=config.negative_prompt,
num_inference_steps=config.sd_steps,
strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
height=height,
width=width,
seed=config.sd_seed,
sort_priority="y",
callback=self.callback
)
        inpainted_bgr_image = results[0][..., ::-1]  # results are RGB; flip channels to BGR
        return inpainted_bgr_image
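
A hedged calling sketch for the class above (assuming anytext_model is a constructed AnyText instance; the InpaintRequest field values are illustrative and the dummy arrays merely satisfy the documented shapes):

import numpy as np
from inpaint.schema import InpaintRequest

image = np.zeros((512, 512, 3), dtype=np.uint8)   # RGB input
mask = np.zeros((512, 512, 1), dtype=np.uint8)
mask[200:260, 100:400] = 255                      # region whose text will be rewritten

config = InpaintRequest(prompt='a sign that says "HELLO"')   # other sd_* fields keep their defaults
# bgr_result = anytext_model(image, mask, config)  # __call__ on the base class handles pad_mod=64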

inpaint/model/anytext/anytext_pipeline.py
@@ -0,0 +1,403 @@
"""
AnyText: Multilingual Visual Text Generation And Editing
Paper: https://arxiv.org/abs/2311.03054
Code: https://github.com/tyxsspa/AnyText
Copyright (c) Alibaba, Inc. and its affiliates.
"""
import os
from pathlib import Path
from inpaint.model.utils import set_seed
from safetensors.torch import load_file
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import torch
import re
import numpy as np
import cv2
import einops
from PIL import ImageFont
from inpaint.model.anytext.cldm.model import create_model, load_state_dict
from inpaint.model.anytext.cldm.ddim_hacked import DDIMSampler
from inpaint.model.anytext.utils import (
check_channels,
draw_glyph,
draw_glyph2,
)
BBOX_MAX_NUM = 8
PLACE_HOLDER = "*"
max_chars = 20
ANYTEXT_CFG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
)
def check_limits(tensor):
float16_min = torch.finfo(torch.float16).min
float16_max = torch.finfo(torch.float16).max
    # check whether any value in the tensor is below float16's minimum or above its maximum
is_below_min = (tensor < float16_min).any()
is_above_max = (tensor > float16_max).any()
return is_below_min or is_above_max
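
A quick sanity check of the helper (float16 overflows above roughly 65504):

import torch

print(check_limits(torch.tensor([0.0, 1e5])))  # tensor(True): 1e5 exceeds torch.finfo(torch.float16).max
print(check_limits(torch.ones(3)))             # tensor(False): safely within range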
class AnyTextPipeline:
def __init__(self, ckpt_path, font_path, device, use_fp16=True):
self.cfg_path = ANYTEXT_CFG
self.font_path = font_path
self.use_fp16 = use_fp16
self.device = device
self.font = ImageFont.truetype(font_path, size=60)
self.model = create_model(
self.cfg_path,
device=self.device,
use_fp16=self.use_fp16,
)
if self.use_fp16:
self.model = self.model.half()
if Path(ckpt_path).suffix == ".safetensors":
state_dict = load_file(ckpt_path, device="cpu")
else:
state_dict = load_state_dict(ckpt_path, location="cpu")
self.model.load_state_dict(state_dict, strict=False)
self.model = self.model.eval().to(self.device)
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
def __call__(
self,
prompt: str,
negative_prompt: str,
image: np.ndarray,
masked_image: np.ndarray,
num_inference_steps: int,
strength: float,
guidance_scale: float,
height: int,
width: int,
seed: int,
sort_priority: str = "y",
callback=None,
):
"""
Args:
prompt:
negative_prompt:
image:
masked_image:
num_inference_steps:
strength:
guidance_scale:
height:
width:
seed:
sort_priority: x: left-right, y: top-down
Returns:
result: list of images in numpy.ndarray format
            rst_code: 0 = normal, -1 = error, 1 = warning
rst_info: string of error or warning
"""
set_seed(seed)
str_warning = ""
mode = "text-editing"
revise_pos = False
img_count = 1
ddim_steps = num_inference_steps
w = width
h = height
strength = strength
cfg_scale = guidance_scale
eta = 0.0
prompt, texts = self.modify_prompt(prompt)
if prompt is None and texts is None:
            return (
                None,
                -1,
                "You have input Chinese prompt but the translator is not loaded!",
            )
n_lines = len(texts)
if mode in ["text-generation", "gen"]:
edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
elif mode in ["text-editing", "edit"]:
if masked_image is None or image is None:
                return (
                    None,
                    -1,
                    "Reference image and position image are needed for text editing!",
                )
if isinstance(image, str):
image = cv2.imread(image)[..., ::-1]
                assert image is not None, f"Can't read ori_image from {image}!"
elif isinstance(image, torch.Tensor):
image = image.cpu().numpy()
else:
assert isinstance(
image, np.ndarray
), f"Unknown format of ori_image: {type(image)}"
edit_image = image.clip(1, 255) # for mask reason
edit_image = check_channels(edit_image)
# edit_image = resize_image(
# edit_image, max_length=768
# ) # make w h multiple of 64, resize if w or h > max_length
h, w = edit_image.shape[:2] # change h, w by input ref_img
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
if masked_image is None:
pos_imgs = np.zeros((w, h, 1))
if isinstance(masked_image, str):
masked_image = cv2.imread(masked_image)[..., ::-1]
            assert (
                masked_image is not None
            ), f"Can't read draw_pos image from {masked_image}!"
pos_imgs = 255 - masked_image
elif isinstance(masked_image, torch.Tensor):
pos_imgs = masked_image.cpu().numpy()
else:
assert isinstance(
masked_image, np.ndarray
), f"Unknown format of draw_pos: {type(masked_image)}"
pos_imgs = 255 - masked_image
pos_imgs = pos_imgs[..., 0:1]
pos_imgs = cv2.convertScaleAbs(pos_imgs)
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
        # separate pos_imgs
pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
if len(pos_imgs) == 0:
pos_imgs = [np.zeros((h, w, 1))]
if len(pos_imgs) < n_lines:
if n_lines == 1 and texts[0] == " ":
pass # text-to-image without text
else:
                raise RuntimeError(
                    f"{n_lines} text lines to draw from the prompt, but only {len(pos_imgs)} mask area(s) on the image"
                )
elif len(pos_imgs) > n_lines:
str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
# get pre_pos, poly_list, hint that needed for anytext
pre_pos = []
poly_list = []
for input_pos in pos_imgs:
if input_pos.mean() != 0:
input_pos = (
input_pos[..., np.newaxis]
if len(input_pos.shape) == 2
else input_pos
)
poly, pos_img = self.find_polygon(input_pos)
pre_pos += [pos_img / 255.0]
poly_list += [poly]
else:
pre_pos += [np.zeros((h, w, 1))]
poly_list += [None]
np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
# prepare info dict
info = {}
info["glyphs"] = []
info["gly_line"] = []
info["positions"] = []
info["n_lines"] = [len(texts)] * img_count
gly_pos_imgs = []
for i in range(len(texts)):
text = texts[i]
if len(text) > max_chars:
str_warning = (
f'"{text}" length > max_chars: {max_chars}, will be cut off...'
)
text = text[:max_chars]
gly_scale = 2
if pre_pos[i].mean() != 0:
gly_line = draw_glyph(self.font, text)
glyphs = draw_glyph2(
self.font,
text,
poly_list[i],
scale=gly_scale,
width=w,
height=h,
add_space=False,
)
gly_pos_img = cv2.drawContours(
glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
)
if revise_pos:
resize_gly = cv2.resize(
glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
)
new_pos = cv2.morphologyEx(
(resize_gly * 255).astype(np.uint8),
cv2.MORPH_CLOSE,
kernel=np.ones(
(resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
dtype=np.uint8,
),
iterations=1,
)
new_pos = (
new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
)
contours, _ = cv2.findContours(
new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
)
if len(contours) != 1:
str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
else:
rect = cv2.minAreaRect(contours[0])
                        poly = cv2.boxPoints(rect).astype(np.int32)
pre_pos[i] = (
cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
)
gly_pos_img = cv2.drawContours(
glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
)
gly_pos_imgs += [gly_pos_img] # for show
else:
glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
gly_line = np.zeros((80, 512, 1))
gly_pos_imgs += [
np.zeros((h * gly_scale, w * gly_scale, 1))
] # for show
pos = pre_pos[i]
info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
info["positions"] += [self.arr2tensor(pos, img_count)]
# get masked_x
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
masked_img = np.transpose(masked_img, (2, 0, 1))
masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
if self.use_fp16:
masked_img = masked_img.half()
encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
if self.use_fp16:
masked_x = masked_x.half()
info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
hint = self.arr2tensor(np_hint, img_count)
cond = self.model.get_learned_conditioning(
dict(
c_concat=[hint],
c_crossattn=[[prompt] * img_count],
text_info=info,
)
)
un_cond = self.model.get_learned_conditioning(
dict(
c_concat=[hint],
c_crossattn=[[negative_prompt] * img_count],
text_info=info,
)
)
shape = (4, h // 8, w // 8)
self.model.control_scales = [strength] * 13
samples, intermediates = self.ddim_sampler.sample(
ddim_steps,
img_count,
shape,
cond,
verbose=False,
eta=eta,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=un_cond,
callback=callback
)
if self.use_fp16:
samples = samples.half()
x_samples = self.model.decode_first_stage(samples)
x_samples = (
(einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
.cpu()
.numpy()
.clip(0, 255)
.astype(np.uint8)
)
results = [x_samples[i] for i in range(img_count)]
# if (
# mode == "edit" and False
        # ): # replace background in text editing but not ideal yet
# results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
# results = [r.clip(0, 255).astype(np.uint8) for r in results]
# if len(gly_pos_imgs) > 0 and show_debug:
# glyph_bs = np.stack(gly_pos_imgs, axis=2)
# glyph_img = np.sum(glyph_bs, axis=2) * 255
# glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
# results += [np.repeat(glyph_img, 3, axis=2)]
rst_code = 1 if str_warning else 0
return results, rst_code, str_warning
def modify_prompt(self, prompt):
        prompt = prompt.replace("“", '"')
        prompt = prompt.replace("”", '"')
p = '"(.*?)"'
strs = re.findall(p, prompt)
if len(strs) == 0:
strs = [" "]
else:
for s in strs:
prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
# if self.is_chinese(prompt):
# if self.trans_pipe is None:
# return None, None
# old_prompt = prompt
# prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
# print(f"Translate: {old_prompt} --> {prompt}")
return prompt, strs
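
Since self is never referenced in modify_prompt, it can be exercised unbound for a quick check of the placeholder substitution:

prompt, texts = AnyTextPipeline.modify_prompt(None, 'A neon sign that says "OPEN" and "24H"')
print(prompt)  # 'A neon sign that says  *  and  * '
print(texts)   # ['OPEN', '24H']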
# def is_chinese(self, text):
# text = checker._clean_text(text)
# for char in text:
# cp = ord(char)
# if checker._is_chinese_char(cp):
# return True
# return False
def separate_pos_imgs(self, img, sort_priority, gap=102):
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
components = []
for label in range(1, num_labels):
component = np.zeros_like(img)
component[labels == label] = 255
components.append((component, centroids[label]))
if sort_priority == "y":
fir, sec = 1, 0 # top-down first
elif sort_priority == "x":
fir, sec = 0, 1 # left-right first
components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
sorted_components = [c[0] for c in components]
return sorted_components
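
separate_pos_imgs also ignores self, so its top-down ordering can be verified unbound (gap shrunk here so the two blobs land in different row buckets):

import numpy as np

mask = np.zeros((100, 100), dtype=np.uint8)
mask[60:80, 40:60] = 255   # lower blob
mask[10:30, 10:30] = 255   # upper blob
comps = AnyTextPipeline.separate_pos_imgs(None, mask, sort_priority="y", gap=20)
assert (comps[0][10:30, 10:30] == 255).all()   # upper blob comes first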
def find_polygon(self, image, min_rect=False):
contours, hierarchy = cv2.findContours(
image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
)
max_contour = max(contours, key=cv2.contourArea) # get contour with max area
if min_rect:
# get minimum enclosing rectangle
rect = cv2.minAreaRect(max_contour)
            poly = cv2.boxPoints(rect).astype(np.int32)
else:
# get approximate polygon
epsilon = 0.01 * cv2.arcLength(max_contour, True)
poly = cv2.approxPolyDP(max_contour, epsilon, True)
n, _, xy = poly.shape
poly = poly.reshape(n, xy)
cv2.drawContours(image, [poly], -1, 255, -1)
return poly, image
def arr2tensor(self, arr, bs):
arr = np.transpose(arr, (2, 0, 1))
_arr = torch.from_numpy(arr.copy()).float().to(self.device)
if self.use_fp16:
_arr = _arr.half()
_arr = torch.stack([_arr for _ in range(bs)], dim=0)
return _arr

inpaint/model/anytext/anytext_sd15.yaml
@@ -0,0 +1,99 @@
model:
  target: inpaint.model.anytext.cldm.cldm.ControlLDM
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "img"
cond_stage_key: "caption"
control_key: "hint"
glyph_key: "glyphs"
position_key: "positions"
image_size: 64
channels: 4
cond_stage_trainable: true # need be true when embedding_manager is valid
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
only_mid_control: False
loss_alpha: 0 # perceptual loss, 0.003
loss_beta: 0 # ctc loss
    latin_weight: 1.0 # latin text lines may need a smaller weight
with_step_weight: true
use_vae_upsample: true
embedding_manager_config:
      target: inpaint.model.anytext.cldm.embedding_manager.EmbeddingManager
params:
valid: true # v6
emb_type: ocr # ocr, vit, conv
glyph_channels: 1
position_channels: 1
add_pos: false
placeholder_string: '*'
control_stage_config:
      target: inpaint.model.anytext.cldm.cldm.ControlNet
params:
image_size: 32 # unused
in_channels: 4
model_channels: 320
glyph_channels: 1
position_channels: 1
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
unet_config:
      target: inpaint.model.anytext.cldm.cldm.ControlledUnetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
      target: inpaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
      target: inpaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
params:
version: openai/clip-vit-large-patch14
use_vision: false # v6
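
create_model in cldm/model.py consumes this file through OmegaConf; a minimal sketch of the same mechanism (assuming omegaconf is installed, as the ldm stack requires):

from omegaconf import OmegaConf
from inpaint.model.anytext.ldm.util import instantiate_from_config

config = OmegaConf.load("inpaint/model/anytext/anytext_sd15.yaml")
model = instantiate_from_config(config.model)  # builds the ControlLDM tree defined above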

inpaint/model/anytext/cldm/__init__.py

inpaint/model/anytext/cldm/cldm.py
@@ -0,0 +1,630 @@
import os
from pathlib import Path
import einops
import torch
import torch as th
import torch.nn as nn
import copy
from easydict import EasyDict as edict
from inpaint.model.anytext.ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from einops import rearrange, repeat
from inpaint.model.anytext.ldm.modules.attention import SpatialTransformer
from inpaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from inpaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
from inpaint.model.anytext.ldm.util import log_txt_as_img, exists, instantiate_from_config
from inpaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
from inpaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from torchvision.utils import make_grid  # used by log_images below
from .recognizer import TextRecognizer, create_predictor
CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class ControlledUnetModel(UNetModel):
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
hs = []
with torch.no_grad():
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
if self.use_fp16:
t_emb = t_emb.half()
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
for i, module in enumerate(self.output_blocks):
if only_mid_control or control is None:
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
glyph_channels,
position_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.use_fp16 = use_fp16
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
self.glyph_block = TimestepEmbedSequential(
conv_nd(dims, glyph_channels, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(),
)
self.position_block = TimestepEmbedSequential(
conv_nd(dims, position_channels, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 64, 3, padding=1, stride=2),
nn.SiLU(),
)
self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, text_info, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
if self.use_fp16:
t_emb = t_emb.half()
emb = self.time_embed(t_emb)
# guided_hint from text_info
B, C, H, W = x.shape
glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
enc_glyph = self.glyph_block(glyphs, emb, context)
enc_pos = self.position_block(positions, emb, context)
guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
class ControlLDM(LatentDiffusion):
def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
self.use_fp16 = kwargs.pop('use_fp16', False)
super().__init__(*args, **kwargs)
self.control_model = instantiate_from_config(control_stage_config)
self.control_key = control_key
self.glyph_key = glyph_key
self.position_key = position_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13
self.loss_alpha = loss_alpha
self.loss_beta = loss_beta
self.with_step_weight = with_step_weight
self.use_vae_upsample = use_vae_upsample
self.latin_weight = latin_weight
if embedding_manager_config is not None and embedding_manager_config.params.valid:
self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
for param in self.embedding_manager.embedding_parameters():
param.requires_grad = True
else:
self.embedding_manager = None
if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
if embedding_manager_config.params.emb_type == 'ocr':
self.text_predictor = create_predictor().eval()
args = edict()
args.rec_image_shape = "3, 48, 320"
args.rec_batch_num = 6
args.rec_char_dict_path = str(CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt")
args.use_fp16 = self.use_fp16
self.cn_recognizer = TextRecognizer(args, self.text_predictor)
for param in self.text_predictor.parameters():
param.requires_grad = False
if self.embedding_manager:
self.embedding_manager.recog = self.cn_recognizer
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
if self.embedding_manager is None: # fill in full caption
self.fill_caption(batch)
x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
control = batch[self.control_key] # for log_images and loss_alpha, not real control
if bs is not None:
control = control[:bs]
control = control.to(self.device)
control = einops.rearrange(control, 'b h w c -> b c h w')
control = control.to(memory_format=torch.contiguous_format).float()
inv_mask = batch['inv_mask']
if bs is not None:
inv_mask = inv_mask[:bs]
inv_mask = inv_mask.to(self.device)
inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
glyphs = batch[self.glyph_key]
gly_line = batch['gly_line']
positions = batch[self.position_key]
n_lines = batch['n_lines']
language = batch['language']
texts = batch['texts']
assert len(glyphs) == len(positions)
for i in range(len(glyphs)):
if bs is not None:
glyphs[i] = glyphs[i][:bs]
gly_line[i] = gly_line[i][:bs]
positions[i] = positions[i][:bs]
n_lines = n_lines[:bs]
glyphs[i] = glyphs[i].to(self.device)
gly_line[i] = gly_line[i].to(self.device)
positions[i] = positions[i].to(self.device)
glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
info = {}
info['glyphs'] = glyphs
info['positions'] = positions
info['n_lines'] = n_lines
info['language'] = language
info['texts'] = texts
info['img'] = batch['img'] # nhwc, (-1,1)
info['masked_x'] = mx
info['gly_line'] = gly_line
info['inv_mask'] = inv_mask
return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
_cond = torch.cat(cond['c_crossattn'], 1)
_hint = torch.cat(cond['c_concat'], 1)
if self.use_fp16:
x_noisy = x_noisy.half()
control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
control = [c * scale for c, scale in zip(control, self.control_scales)]
eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
return eps
def instantiate_embedding_manager(self, config, embedder):
model = instantiate_from_config(config, embedder=embedder)
return model
@torch.no_grad()
def get_unconditional_conditioning(self, N):
return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
if self.embedding_manager is not None and c['text_info'] is not None:
self.embedding_manager.encode_text(c['text_info'])
if isinstance(c, dict):
cond_txt = c['c_crossattn'][0]
else:
cond_txt = c
if self.embedding_manager is not None:
cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
else:
cond_txt = self.cond_stage_model.encode(cond_txt)
if isinstance(c, dict):
c['c_crossattn'][0] = cond_txt
else:
c = cond_txt
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def fill_caption(self, batch, place_holder='*'):
bs = len(batch['n_lines'])
cond_list = copy.deepcopy(batch[self.cond_stage_key])
for i in range(bs):
n_lines = batch['n_lines'][i]
if n_lines == 0:
continue
cur_cap = cond_list[i]
for j in range(n_lines):
r_txt = batch['texts'][j][i]
cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
cond_list[i] = cur_cap
batch[self.cond_stage_key] = cond_list
@torch.no_grad()
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
use_ddim = ddim_steps is not None
log = dict()
z, c = self.get_input(batch, self.first_stage_key, bs=N)
if self.cond_stage_trainable:
with torch.no_grad():
c = self.get_learned_conditioning(c)
c_crossattn = c["c_crossattn"][0][:N]
c_cat = c["c_concat"][0][:N]
text_info = c["text_info"]
text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
text_info['positions'] = [i[:N] for i in text_info['positions']]
text_info['n_lines'] = text_info['n_lines'][:N]
text_info['masked_x'] = text_info['masked_x'][:N]
text_info['img'] = text_info['img'][:N]
N = min(z.shape[0], N)
n_row = min(z.shape[0], n_row)
log["reconstruction"] = self.decode_first_stage(z)
log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
log["control"] = c_cat * 2.0 - 1.0
log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed
# get glyph
glyph_bs = torch.stack(text_info['glyphs'])
glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
# fill caption
if not self.embedding_manager:
self.fill_caption(batch)
captions = batch[self.cond_stage_key]
log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N)
uc_cat = c_cat # torch.zeros_like(c_cat)
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
            pred_x0 = False  # whether to log pred_x0
if pred_x0:
for idx in range(len(tmps['pred_x0'])):
pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
return log
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
ddim_sampler = DDIMSampler(self)
b, c, h, w = cond["c_concat"][0].shape
shape = (self.channels, h // 8, w // 8)
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
return samples, intermediates
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control_model.parameters())
if self.embedding_manager:
params += list(self.embedding_manager.embedding_parameters())
if not self.sd_locked:
# params += list(self.model.diffusion_model.input_blocks.parameters())
# params += list(self.model.diffusion_model.middle_block.parameters())
params += list(self.model.diffusion_model.output_blocks.parameters())
params += list(self.model.diffusion_model.out.parameters())
if self.unlockKV:
nCount = 0
for name, param in self.model.diffusion_model.named_parameters():
if 'attn2.to_k' in name or 'attn2.to_v' in name:
params += [param]
nCount += 1
            print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to optimizers!!!')
opt = torch.optim.AdamW(params, lr=lr)
return opt
def low_vram_shift(self, is_diffusing):
if is_diffusing:
self.model = self.model.cuda()
self.control_model = self.control_model.cuda()
self.first_stage_model = self.first_stage_model.cpu()
self.cond_stage_model = self.cond_stage_model.cpu()
else:
self.model = self.model.cpu()
self.control_model = self.control_model.cpu()
self.first_stage_model = self.first_stage_model.cuda()
self.cond_stage_model = self.cond_stage_model.cuda()

inpaint/model/anytext/cldm/ddim_hacked.py
@@ -0,0 +1,486 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from inpaint.model.anytext.ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object):
def __init__(self, model, device, schedule="linear", **kwargs):
super().__init__()
self.device = device
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f"Data shape for DDIM sampling is {size}, eta {eta}")
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
ucg_schedule=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
img, pred_x0 = outs
if callback:
callback(None, i, None, None)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (
model_t - model_uncond
)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", "not implemented"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
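
For orientation, p_sample_ddim implements the DDIM update (Song et al., 2021, Eq. 12), where the alphas are the cumulative products stored in the ddim_* buffers:

    x_{t-1} = \sqrt{a_{prev}} \cdot \underbrace{\frac{x_t - \sqrt{1 - a_t}\,\epsilon_\theta}{\sqrt{a_t}}}_{\text{pred\_x0}} + \underbrace{\sqrt{1 - a_{prev} - \sigma_t^2}\,\epsilon_\theta}_{\text{dir\_xt}} + \underbrace{\sigma_t\,\varepsilon}_{\text{noise}}

With eta = 0 (the value AnyTextPipeline passes), sigma_t = 0 and the sampler is deterministic; eta = 1 recovers DDPM-like stochasticity.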
@torch.no_grad()
def encode(
self,
x0,
c,
t_enc,
use_original_steps=False,
return_intermediates=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
callback=None,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
num_reference_steps = timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc="Encoding Image"):
t = torch.full(
(x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
)
if unconditional_guidance_scale == 1.0:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(
torch.cat((x_next, x_next)),
torch.cat((t, t)),
torch.cat((unconditional_conditioning, c)),
),
2,
)
noise_pred = e_t_uncond + unconditional_guidance_scale * (
noise_pred - e_t_uncond
)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = (
alphas_next[i].sqrt()
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
* noise_pred
)
x_next = xt_weighted + weighted_noise_pred
if (
return_intermediates
and i % (num_steps // return_intermediates) == 0
and i < num_steps - 1
):
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback:
callback(i)
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
if return_intermediates:
out.update({"intermediates": intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
)
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
callback=None,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
if callback:
callback(i)
return x_dec

inpaint/model/anytext/cldm/embedding_manager.py
@@ -0,0 +1,165 @@
'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from inpaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
token = token[0, 1]
return token
def get_clip_vision_emb(encoder, processor, img):
_img = img.repeat(1, 3, 1, 1)*255
inputs = processor(images=_img, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
outputs = encoder(**inputs)
emb = outputs.image_embeds
return emb
def get_recog_emb(encoder, img_list):
_img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
encoder.predictor.eval()
_, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
return preds_neck
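# Symmetrically pad the height of a BCHW tensor so that H == W (assumes W >= H).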
def pad_H(x):
_, _, H, W = x.shape
p_top = (W - H) // 2
p_bot = W - H - p_top
return F.pad(x, (0, 0, p_top, p_bot))
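# Small conv encoder: n_layer stride-2 downsamples followed by global average
# pooling, producing one out_channels-sized vector per input image.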
class EncodeNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(EncodeNet, self).__init__()
chan = 16
n_layer = 4 # downsample
self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
self.conv_list = nn.ModuleList([])
_c = chan
for i in range(n_layer):
self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
_c *= 2
self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.act = nn.SiLU()
def forward(self, x):
x = self.act(self.conv1(x))
for layer in self.conv_list:
x = self.act(layer(x))
x = self.act(self.conv2(x))
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
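# Replaces occurrences of the placeholder token in a tokenized prompt with
# per-text-line glyph embeddings computed by encode_text().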
class EmbeddingManager(nn.Module):
def __init__(
self,
embedder,
valid=True,
glyph_channels=20,
position_channels=1,
placeholder_string='*',
add_pos=False,
emb_type='ocr',
**kwargs
):
super().__init__()
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
token_dim = 768
if hasattr(embedder, 'vit'):
assert emb_type == 'vit'
self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
self.get_recog_emb = None
else: # using LDM's BERT encoder
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
token_dim = 1280
self.token_dim = token_dim
self.emb_type = emb_type
self.add_pos = add_pos
if add_pos:
self.position_encoder = EncodeNet(position_channels, token_dim)
if emb_type == 'ocr':
self.proj = linear(40*64, token_dim)
if emb_type == 'conv':
self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
self.placeholder_token = get_token_for_string(placeholder_string)
def encode_text(self, text_info):
if self.get_recog_emb is None and self.emb_type == 'ocr':
self.get_recog_emb = partial(get_recog_emb, self.recog)
gline_list = []
pos_list = []
for i in range(len(text_info['n_lines'])): # sample index in a batch
n_lines = text_info['n_lines'][i]
for j in range(n_lines): # line
gline_list += [text_info['gly_line'][j][i:i+1]]
if self.add_pos:
pos_list += [text_info['positions'][j][i:i+1]]
if len(gline_list) > 0:
if self.emb_type == 'ocr':
recog_emb = self.get_recog_emb(gline_list)
enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
elif self.emb_type == 'vit':
enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
elif self.emb_type == 'conv':
enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
if self.add_pos:
enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
enc_glyph = enc_glyph+enc_pos
self.text_embs_all = []
n_idx = 0
for i in range(len(text_info['n_lines'])): # sample index in a batch
n_lines = text_info['n_lines'][i]
text_embs = []
for j in range(n_lines): # line
text_embs += [enc_glyph[n_idx:n_idx+1]]
n_idx += 1
self.text_embs_all += [text_embs]
def forward(
self,
tokenized_text,
embedded_text,
):
b, device = tokenized_text.shape[0], tokenized_text.device
for i in range(b):
idx = tokenized_text[i] == self.placeholder_token.to(device)
if sum(idx) > 0:
if i >= len(self.text_embs_all):
print('truncation for log images...')
break
text_emb = torch.cat(self.text_embs_all[i], dim=0)
if sum(idx) != len(text_emb):
print('truncation for long caption...')
embedded_text[i][idx] = text_emb[:sum(idx)]
return embedded_text
def embedding_parameters(self):
return self.parameters()

View File

@@ -0,0 +1,111 @@
import torch
import einops
import iopaint.model.anytext.ldm.modules.encoders.modules
import iopaint.model.anytext.ldm.modules.attention
from transformers import logging
from iopaint.model.anytext.ldm.modules.attention import default
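# Runtime monkey-patches: disable_verbosity() silences transformers logging,
# hack_everything() installs a clip-skip aware CLIP forward, and
# enable_sliced_attention() swaps in a memory-saving sliced attention forward.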
def disable_verbosity():
logging.set_verbosity_error()
print('logging improved.')
return
def enable_sliced_attention():
iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attention_forward
print('Enabled sliced_attention.')
return
def hack_everything(clip_skip=0):
disable_verbosity()
iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
print('Enabled clip hacks.')
return
# Written by Lvmin
def _hacked_clip_forward(self, text):
PAD = self.tokenizer.pad_token_id
EOS = self.tokenizer.eos_token_id
BOS = self.tokenizer.bos_token_id
def tokenize(t):
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
def transformer_encode(t):
if self.clip_skip > 1:
rt = self.transformer(input_ids=t, output_hidden_states=True)
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
else:
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
def split(x):
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
def pad(x, p, i):
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
raw_tokens_list = tokenize(text)
tokens_list = []
for raw_tokens in raw_tokens_list:
raw_tokens_123 = split(raw_tokens)
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
tokens_list.append(raw_tokens_123)
tokens_list = torch.IntTensor(tokens_list).to(self.device)
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
y = transformer_encode(feed)
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
return z
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
limit = k.shape[0]
att_step = 1
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range(0, limit, att_step):
q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
del k_buffer, q_buffer
# attention, what we cannot get enough of, by chunks
sim_buffer = sim_buffer.softmax(dim=-1)
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i + att_step, :, :] = sim_buffer
del sim_buffer
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)

View File

@@ -0,0 +1,40 @@
import os
import torch
from omegaconf import OmegaConf
from iopaint.model.anytext.ldm.util import instantiate_from_config
def get_state_dict(d):
return d.get("state_dict", d)
def load_state_dict(ckpt_path, location="cpu"):
_, extension = os.path.splitext(ckpt_path)
if extension.lower() == ".safetensors":
import safetensors.torch
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
else:
state_dict = get_state_dict(
torch.load(ckpt_path, map_location=torch.device(location))
)
state_dict = get_state_dict(state_dict)
print(f"Loaded state_dict from [{ckpt_path}]")
return state_dict
def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
config = OmegaConf.load(config_path)
# if cond_stage_path:
# config.model.params.cond_stage_config.params.version = (
# cond_stage_path # use pre-downloaded ckpts, in case blocked
# )
config.model.params.cond_stage_config.params.device = str(device)
if use_fp16:
config.model.params.use_fp16 = True
config.model.params.control_stage_config.params.use_fp16 = True
config.model.params.unet_config.params.use_fp16 = True
model = instantiate_from_config(config.model).cpu()
print(f"Loaded model config from [{config_path}]")
return model
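# Hypothetical usage sketch (the checkpoint file name is a placeholder):
#   model = create_model("anytext_sd15.yaml", device="cuda", use_fp16=True)
#   model.load_state_dict(load_state_dict("anytext.ckpt", location="cpu"), strict=False)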

View File

@@ -0,0 +1,300 @@
"""
Copyright (c) Alibaba, Inc. and its affiliates.
"""
import os
import cv2
import numpy as np
import math
import traceback
from easydict import EasyDict as edict
import time
from iopaint.model.anytext.ocr_recog.RecModel import RecModel
import torch
import torch.nn.functional as F
def min_bounding_rect(img):
ret, thresh = cv2.threshold(img, 127, 255, 0)
contours, hierarchy = cv2.findContours(
thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
if len(contours) == 0:
print("Bad contours, using fake bbox...")
return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
max_contour = max(contours, key=cv2.contourArea)
rect = cv2.minAreaRect(max_contour)
box = cv2.boxPoints(rect)
box = np.intp(box)  # np.int0 was removed in NumPy 2.0; np.intp is the same alias
# sort
x_sorted = sorted(box, key=lambda x: x[0])
left = x_sorted[:2]
right = x_sorted[2:]
left = sorted(left, key=lambda x: x[1])
(tl, bl) = left
right = sorted(right, key=lambda x: x[1])
(tr, br) = right
if tl[1] > bl[1]:
(tl, bl) = (bl, tl)
if tr[1] > br[1]:
(tr, br) = (br, tr)
return np.array([tl, tr, br, bl])
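# Build the OCR recognition model: either an ONNX Runtime session or a PyTorch
# RecModel (MobileNetV1-enhance backbone, SVTR neck, CTC head) in eval mode.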
def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
model_file_path = model_dir
if model_file_path is not None and not os.path.exists(model_file_path):
raise ValueError("model file path {} does not exist".format(model_file_path))
if is_onnx:
import onnxruntime as ort
sess = ort.InferenceSession(
model_file_path, providers=["CPUExecutionProvider"]
) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
return sess
else:
if model_lang == "ch":
n_class = 6625
elif model_lang == "en":
n_class = 97
else:
raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
rec_config = edict(
in_channels=3,
backbone=edict(
type="MobileNetV1Enhance",
scale=0.5,
last_conv_stride=[1, 2],
last_pool_type="avg",
),
neck=edict(
type="SequenceEncoder",
encoder_type="svtr",
dims=64,
depth=2,
hidden_dims=120,
use_guide=True,
),
head=edict(
type="CTCHead",
fc_decay=0.00001,
out_channels=n_class,
return_feats=True,
),
)
rec_model = RecModel(rec_config)
if model_file_path is not None:
rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
return rec_model.eval()
def _check_image_file(path):
img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
return any([path.lower().endswith(e) for e in img_end])
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("no image files found in {}".format(img_file))
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("no image files found in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
class TextRecognizer(object):
def __init__(self, args, predictor):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.rec_batch_num = args.rec_batch_num
self.predictor = predictor
self.chars = self.get_char_dict(args.rec_char_dict_path)
self.char2id = {x: i for i, x in enumerate(self.chars)}
self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
self.use_fp16 = args.use_fp16
# img: CHW
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[0]
imgW = int((imgH * max_wh_ratio))
h, w = img.shape[1:]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = torch.nn.functional.interpolate(
img.unsqueeze(0),
size=(imgH, resized_w),
mode="bilinear",
align_corners=True,
)
resized_image /= 255.0
resized_image -= 0.5
resized_image /= 0.5
padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
padding_im[:, :, 0:resized_w] = resized_image[0]
return padding_im
# img_list: list of tensors with shape chw 0-255
def pred_imglist(self, img_list, show_debug=False, is_ori=False):
img_num = len(img_list)
assert img_num > 0
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[2] / float(img.shape[1]))
# Sorting can speed up the recognition process
indices = torch.from_numpy(np.argsort(np.array(width_list)))
batch_num = self.rec_batch_num
preds_all = [None] * img_num
preds_neck_all = [None] * img_num
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[1:]
if h > w * 1.2:
img = img_list[indices[ino]]
img = torch.transpose(img, 1, 2).flip(dims=[1])
img_list[indices[ino]] = img
h, w = img.shape[1:]
# wh_ratio = w * 1.0 / h
# max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
if self.use_fp16:
norm_img = norm_img.half()
norm_img = norm_img.unsqueeze(0)
norm_img_batch.append(norm_img)
norm_img_batch = torch.cat(norm_img_batch, dim=0)
if show_debug:
for i in range(len(norm_img_batch)):
_img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
_img = (_img + 0.5) * 255
_img = _img[:, :, ::-1]
file_name = f"{indices[beg_img_no + i]}"
file_name = file_name + "_ori" if is_ori else file_name
cv2.imwrite(file_name + ".jpg", _img)
if self.is_onnx:
input_dict = {}
input_dict[self.predictor.get_inputs()[0].name] = (
norm_img_batch.detach().cpu().numpy()
)
outputs = self.predictor.run(None, input_dict)
preds = {}
preds["ctc"] = torch.from_numpy(outputs[0])
preds["ctc_neck"] = [torch.zeros(1)] * img_num
else:
preds = self.predictor(norm_img_batch)
for rno in range(preds["ctc"].shape[0]):
preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
def get_char_dict(self, character_dict_path):
character_str = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode("utf-8").strip("\n").strip("\r\n")
character_str.append(line)
dict_character = list(character_str)
dict_character = ["sos"] + dict_character + [" "] # eos is space
return dict_character
def get_text(self, order):
char_list = [self.chars[text_id] for text_id in order]
return "".join(char_list)
def decode(self, mat):
text_index = mat.detach().cpu().numpy().argmax(axis=1)
ignored_tokens = [0]
selection = np.ones(len(text_index), dtype=bool)
selection[1:] = text_index[1:] != text_index[:-1]
for ignored_token in ignored_tokens:
selection &= text_index != ignored_token
return text_index[selection], np.where(selection)[0]
def get_ctcloss(self, preds, gt_text, weight):
if not isinstance(weight, torch.Tensor):
weight = torch.tensor(weight).to(preds.device)
ctc_loss = torch.nn.CTCLoss(reduction="none")
log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
targets = []
target_lengths = []
for t in gt_text:
targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
target_lengths += [len(t)]
targets = torch.tensor(targets).to(preds.device)
target_lengths = torch.tensor(target_lengths).to(preds.device)
input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(
preds.device
)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss = loss / input_lengths * weight
return loss
def main():
rec_model_dir = "./ocr_weights/ppv3_rec.pth"
predictor = create_predictor(rec_model_dir)
args = edict()
args.rec_image_shape = "3, 48, 320"
args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
args.rec_batch_num = 6
text_recognizer = TextRecognizer(args, predictor)
image_dir = "./test_imgs_cn"
gt_text = ["韩国小馆"] * 14
image_file_list = get_image_file_list(image_dir)
valid_image_file_list = []
img_list = []
for image_file in image_file_list:
img = cv2.imread(image_file)
if img is None:
print("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
try:
tic = time.time()
times = []
for i in range(10):
preds, _ = text_recognizer.pred_imglist(img_list) # get text
preds_all = preds.softmax(dim=2)
times += [(time.time() - tic) * 1000.0]
tic = time.time()
print(times)
print(np.mean(times[1:]) / len(preds_all))
weight = np.ones(len(gt_text))
loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
for i in range(len(valid_image_file_list)):
pred = preds_all[i]
order, idx = text_recognizer.decode(pred)
text = text_recognizer.get_text(order)
print(
f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
)
except Exception as E:
print(traceback.format_exc(), E)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,218 @@
import torch
import torch.nn.functional as F
from contextlib import contextmanager
from iopaint.model.anytext.ldm.modules.diffusionmodules.model import Encoder, Decoder
from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from iopaint.model.anytext.ldm.util import instantiate_from_config
from iopaint.model.anytext.ldm.modules.ema import LitEma
class AutoencoderKL(torch.nn.Module):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x

View File

@@ -0,0 +1,354 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
# cbs = len(ctmp[0])
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img], "index": [10000]}
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
intermediates['index'].append(index)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([
unconditional_conditioning[k][i],
c[k][i]]) for i in range(len(c[k]))]
elif isinstance(c[k], dict):
c_in[k] = dict()
for key in c[k]:
if isinstance(c[k][key], list):
if not isinstance(c[k][key][0], torch.Tensor):
continue
c_in[k][key] = [torch.cat([
unconditional_conditioning[k][key][i],
c[k][key][i]]) for i in range(len(c[k][key]))]
else:
c_in[k][key] = torch.cat([
unconditional_conditioning[k][key],
c[k][key]])
else:
c_in[k] = torch.cat([
unconditional_conditioning[k],
c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec

File diff suppressed because it is too large

View File

@@ -0,0 +1 @@
from .sampler import DPMSolverSampler

File diff suppressed because it is too large

View File

@@ -0,0 +1,87 @@
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
}
class DPMSolverSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
return x.to(device), None
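# Hypothetical usage sketch (model, c and uc are assumed to be an already
# loaded LatentDiffusion and its conditional/unconditional conditioning):
#   sampler = DPMSolverSampler(model)
#   samples, _ = sampler.sample(S=20, batch_size=1, shape=(4, 64, 64),
#                               conditioning=c, unconditional_conditioning=uc,
#                               unconditional_guidance_scale=9.0)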

View File

@@ -0,0 +1,244 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from iopaint.model.anytext.ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next,
dynamic_threshold=dynamic_threshold)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t

View File

@@ -0,0 +1,22 @@
import torch
import numpy as np
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
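# Shrink samples whose per-sample RMS norm exceeds `value`; samples already
# below the threshold are left unchanged (the clamp makes the factor 1).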
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)

View File

@@ -0,0 +1,360 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint
# CrossAttn precision handling
import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
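# ATTN_PRECISION=fp32 (the default) upcasts q/k before the similarity matmul
# to avoid fp16 overflow; set it to fp16 to skip the upcast and save memory.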
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION == "fp32":
with torch.autocast(enabled=False, device_type="cuda"):
q, k = q.float(), k.float()
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
else:
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", sim, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
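# Variant of CrossAttention that uses torch.nn.functional.scaled_dot_product_attention
# (PyTorch >= 2.0); BasicTransformerBlock selects it automatically when available.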
class SDPACrossAttention(CrossAttention):
def forward(self, x, context=None, mask=None):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
head_dim = inner_dim // h
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
del q_in, k_in, v_in
dtype = q.dtype
if _ATTN_PRECISION == "fp32":
q, k, v = q.float(), k.float(), v.float()
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, h * head_dim
)
hidden_states = hidden_states.to(dtype)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
):
super().__init__()
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
attn_cls = SDPACrossAttention
else:
attn_cls = CrossAttention
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = (
self.attn1(
self.norm1(x), context=context if self.disable_self_attn else None
)
+ x
)
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
use_checkpoint=True,
):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
            # project transformer output back to the input channel count
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in

iopaint/model/anytext/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,973 @@
# pytorch_diffusion + derived encoder decoder
import math
import numpy as np
import torch
import torch.nn as nn
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
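# Illustrative usage (not part of the original file):
#   t = torch.tensor([0, 10, 999])
#   emb = get_timestep_embedding(t, 128)  # -> shape (3, 128)
# Each row is [sin(t * w_0), ..., sin(t * w_63), cos(t * w_0), ..., cos(t * w_63)]
# with geometrically spaced frequencies w_i.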
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
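# The (0, 1, 0, 1) pad adds one pixel on the right and bottom only, so the
# stride-2 3x3 conv halves even-sized inputs exactly: 64x64 -> pad -> 65x65
# -> conv -> 32x32, matching TensorFlow-style "SAME" downsampling.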
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class AttnBlock2_0(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention over the h*w spatial positions with SDPA, treating
        # the feature map as a single-head sequence: (b, c, h, w) -> (b, 1, h*w, c)
        b, c, h, w = q.shape
        q = q.reshape(b, 1, c, h * w).transpose(2, 3)
        k = k.reshape(b, 1, c, h * w).transpose(2, 3)
        v = v.reshape(b, 1, c, h * w).transpose(2, 3)
        # (batch, num_heads, seq_len, head_dim)
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        # back to (batch, channels, height, width)
        hidden_states = hidden_states.transpose(2, 3).reshape(b, c, h, w)
        hidden_states = hidden_states.to(q.dtype)
        h_ = self.proj_out(hidden_states)
        return x + h_
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in [
"vanilla",
"vanilla-xformers",
"memory-efficient-cross-attn",
"linear",
"none",
], f"attn_type {attn_type} unknown"
assert attn_kwargs is None
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# print(f"Using torch.nn.functional.scaled_dot_product_attention")
return AttnBlock2_0(in_channels)
return AttnBlock(in_channels)
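# Note: despite the accepted attn_type values, this port always builds vanilla
# attention, preferring the SDPA variant when PyTorch provides
# scaled_dot_product_attention. Illustrative usage (not part of the original
# file):
#   attn = make_attn(512)                  # AttnBlock2_0 on PyTorch >= 2.0
#   y = attn(torch.randn(1, 512, 64, 64))  # residual block: output keeps the input shape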
class Model(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla",
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, x, t=None, context=None):
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
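# Illustrative shape walk (not part of the original file), assuming the usual
# SD-VAE configuration ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
# z_channels=4, double_z=True, resolution=256, in_channels=3:
#   x: (1, 3, 256, 256) is downsampled once per non-final level (three times)
#   to (1, 512, 32, 32), then conv_out projects to (1, 8, 32, 32), i.e. mean
#   and logvar of a 4-channel diagonal Gaussian latent.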
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
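# Illustrative usage (not part of the original file), mirroring the Encoder
# shape walk above: z (1, 4, 32, 32) -> conv_in (1, 512, 32, 32) -> three
# upsampling levels -> norm_out / nonlinearity / conv_out -> (1, 3, 256, 256).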
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList(
[
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(
in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0,
),
ResnetBlock(
in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0,
dropout=0.0,
),
ResnetBlock(
in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0,
),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True),
]
)
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
ch,
num_res_blocks,
resolution,
ch_mult=(2, 2),
dropout=0.0,
):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_channels, kernel_size=3, stride=1, padding=1
)
    def forward(self, x):
        # upsampling
        h = x
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1
)
self.res_block1 = nn.ModuleList(
[
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0,
)
for _ in range(depth)
]
)
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList(
[
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0,
)
for _ in range(depth)
]
)
self.conv_out = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(
x,
size=(
int(round(x.shape[2] * self.factor)),
int(round(x.shape[3] * self.factor)),
),
)
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
x = self.conv_out(x)
return x
class MergedRescaleEncoder(nn.Module):
def __init__(
self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(
in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(
self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(
out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1.0 + (out_size % in_size)
print(
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
)
self.rescaler = LatentRescaler(
factor=factor_up,
in_channels=in_channels,
mid_channels=2 * in_channels,
out_channels=in_channels,
)
self.decoder = Decoder(
out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)],
)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
            print(
                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
            )
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1
)
def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
return x

iopaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
)
from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
from iopaint.model.anytext.ldm.util import exists
# dummy replace
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding."""
    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )
    def forward(self, x):
        return self.up(x)
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
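# With use_scale_shift_norm=True the timestep embedding conditions the block
# FiLM-style: emb_layers emits 2 * out_channels values that are chunked into
# (scale, shift) and applied as h = norm(h) * (1 + scale) + shift, instead of
# simply being added to h.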
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
        return checkpoint(
            self._forward, (x,), self.parameters(), True
        )  # TODO: checkpointing is forced on here; verify this, and fix the .half() call
#return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
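# Layout note: QKVAttentionLegacy expects qkv packed as [N, heads * (3 * ch), T]
# (heads split before qkv), while QKVAttention expects [N, 3 * (heads * ch), T]
# (qkv split before heads); use_new_attention_order in AttentionBlock selects
# between the two.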
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
            if isinstance(context_dim, ListConfig):
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.use_fp16 = use_fp16
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
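# Illustrative usage (not part of the original file), assuming an SD 1.x-style
# configuration (the parameter values below are assumptions, not taken from
# this repository's configs):
#   unet = UNetModel(
#       image_size=32, in_channels=4, model_channels=320, out_channels=4,
#       num_res_blocks=2, attention_resolutions=(4, 2, 1),
#       channel_mult=(1, 2, 4, 4), num_heads=8, use_spatial_transformer=True,
#       transformer_depth=1, context_dim=768,
#   )
#   eps = unet(torch.randn(1, 4, 64, 64), timesteps=torch.tensor([10]),
#              context=torch.randn(1, 77, 768))  # -> (1, 4, 64, 64)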

iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from iopaint.model.anytext.ldm.util import default
class AbstractLowScaleModel(nn.Module):
# for concatenating a downsampled image to the latent representation
def __init__(self, noise_schedule_config=None):
super(AbstractLowScaleModel, self).__init__()
if noise_schedule_config is not None:
self.register_schedule(**noise_schedule_config)
def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def forward(self, x):
return x, None
def decode(self, x):
return x
class SimpleImageConcat(AbstractLowScaleModel):
# no noise level conditioning
def __init__(self):
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
self.max_noise_level = 0
def forward(self, x):
# fix to constant noise level
return x, torch.zeros(x.shape[0], device=x.device).long()
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
return z, noise_level
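# Illustrative usage (not part of the original file): noise-augmenting a
# low-res conditioning image for a noise-level-conditioned upscaler, with an
# assumed linear schedule config:
#   aug = ImageConcatWithNoiseAugmentation(
#       {"beta_schedule": "linear", "timesteps": 1000}, max_noise_level=350)
#   z, noise_level = aug(low_res_images)  # z is q_sample'd at random levels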

iopaint/model/anytext/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,271 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from iopaint.model.anytext.ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
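# Illustrative usage (not part of the original file):
#   betas = make_beta_schedule("linear", 1000)  # np.ndarray, shape (1000,)
#   alphas_cumprod = np.cumprod(1.0 - betas)    # consumed by the samplers below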
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
    # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
return sigmas.to(torch.float32), alphas.to(torch.float32), alphas_prev.astype(np.float32)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
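# Illustrative usage (not part of the original file): gather one schedule value
# per batch element and broadcast it over the image dimensions:
#   a = torch.linspace(0, 1, 1000)        # per-timestep coefficients
#   t = torch.tensor([0, 499, 999])
#   extract_into_tensor(a, t, (3, 4, 64, 64)).shape  # -> (3, 1, 1, 1)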
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
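# When flag is True, CheckpointFunction (below) re-runs func during the
# backward pass instead of storing intermediate activations; params are routed
# through apply() so their gradients are still produced even though func does
# not take them as explicit arguments.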
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad(), \
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
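# Note: unlike get_timestep_embedding in diffusionmodules/model.py, this
# variant concatenates [cos, sin] rather than [sin, cos]; the two orderings are
# not interchangeable for pretrained weights.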
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
# return super().forward(x.float()).type(x.dtype)
return super().forward(x).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()

inpaint/model/anytext/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
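if __name__ == "__main__":
    # Sanity-check sketch (not part of the original file): normal_kl should agree
    # with torch.distributions.kl_divergence for diagonal Gaussians.
    from torch.distributions import Normal, kl_divergence
    mean1, logvar1 = torch.randn(4), torch.randn(4)
    mean2, logvar2 = torch.randn(4), torch.randn(4)
    ours = normal_kl(mean1, logvar1, mean2, logvar2)
    ref = kl_divergence(Normal(mean1, (0.5 * logvar1).exp()),
                        Normal(mean2, (0.5 * logvar2).exp()))
    assert torch.allclose(ours, ref, atol=1e-5)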

inpaint/model/anytext/ldm/modules/ema.py
@@ -0,0 +1,80 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
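if __name__ == "__main__":
    # Usage sketch (not part of the original file): update the shadow weights once
    # per training step, then swap them in around validation.
    model = nn.Linear(4, 4)
    ema = LitEma(model)
    ema(model)                       # after each optimizer step
    ema.store(model.parameters())    # stash the live weights
    ema.copy_to(model)               # validate with EMA weights
    ema.restore(model.parameters())  # back to the live weights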

inpaint/model/anytext/ldm/modules/encoders/modules.py
@@ -0,0 +1,411 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import (
T5Tokenizer,
T5EncoderModel,
CLIPTokenizer,
CLIPTextModel,
AutoProcessor,
CLIPVisionModelWithProjection,
)
from inpaint.model.anytext.ldm.util import count_params
def _expand_mask(mask, dtype, tgt_len=None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
if self.ucg_rate > 0.0 and not disable_dropout:
mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = (
self.n_classes - 1
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs,), device=device) * uc_class
uc = {self.key: uc}
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
layer="last",
layer_idx=None,
): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(
input_ids=tokens, output_hidden_states=self.layer == "hidden"
)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(
self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
clip_max_length=77,
t5_max_length=77,
):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length
)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
)
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
class FrozenCLIPEmbedderT3(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
use_vision=False,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
if use_vision:
self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
self.processor = AutoProcessor.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def embedding_forward(
self,
input_ids=None,
position_ids=None,
inputs_embeds=None,
embedding_manager=None,
):
seq_length = (
input_ids.shape[-1]
if input_ids is not None
else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
self.transformer.text_model.embeddings
)
def encoder_forward(
self,
inputs_embeds,
attention_mask=None,
causal_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
self.transformer.text_model.encoder
)
def text_encoder_forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding_manager=embedding_manager,
)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(
bsz, seq_len, hidden_states.dtype
).to(hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = self.final_layer_norm(last_hidden_state)
return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(
self.transformer.text_model
)
def transformer_forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
embedding_manager=embedding_manager,
)
self.transformer.forward = transformer_forward.__get__(self.transformer)
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs)
return z
def encode(self, text, **kwargs):
return self(text, **kwargs)
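if __name__ == "__main__":
    # Usage sketch (not part of the original file). Downloads
    # openai/clip-vit-large-patch14 on first use; note that forward() only moves
    # the token ids to self.device, so the module itself must be placed on the
    # target device by the caller.
    encoder = FrozenCLIPEmbedder(device="cpu")
    z = encoder.encode(["a photo of a cat"])
    print(z.shape)  # torch.Size([1, 77, 768])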

inpaint/model/anytext/ldm/util.py
@@ -0,0 +1,197 @@
import importlib
import torch
from torch import optim
import numpy as np
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
nc = int(32 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x,torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config, **kwargs):
if "target" not in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
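# Example (hypothetical config, not from the repo): instantiate_from_config
# resolves the dotted path under "target" and forwards "params" as kwargs --
#   config = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 2}}
#   layer = instantiate_from_config(config)  # equivalent to nn.Linear(8, 2)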
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
ema_power=1., param_names=()):
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
ema_power=ema_power, param_names=param_names)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
ema_decay = group['ema_decay']
ema_power = group['ema_power']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
ema_params_with_grad.append(state['param_exp_avg'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
optim._functional.adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
return loss
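if __name__ == "__main__":
    # Sketch (not part of the original file): the EMA copies live in the optimizer
    # state, not in the model. Assumes a torch version whose private
    # optim._functional.adamw still accepts integer state steps.
    model = torch.nn.Linear(4, 1)
    opt = AdamWwithEMAandWings(model.parameters(), lr=1e-4)
    model(torch.randn(2, 4)).sum().backward()
    opt.step()
    for p in model.parameters():
        print(opt.state[p]["param_exp_avg"].shape)  # float32 EMA shadow of p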

inpaint/model/anytext/main.py
@@ -0,0 +1,45 @@
import cv2
import os
from anytext_pipeline import AnyTextPipeline
from utils import save_images
seed = 66273235
# seed_everything(seed)
pipe = AnyTextPipeline(
ckpt_path="/Users/cwq/code/github/IOPaint/iopaint/model/anytext/anytext_v1.1_fp16.ckpt",
font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf",
use_fp16=False,
device="mps",
)
img_save_folder = "SaveImages"
rgb_image = cv2.imread(
"/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg"
)[..., ::-1]
masked_image = cv2.imread(
"/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png"
)[..., ::-1]
rgb_image = cv2.resize(rgb_image, (512, 512))
masked_image = cv2.resize(masked_image, (512, 512))
# results: list of rgb ndarray
results, rtn_code, rtn_warning = pipe(
prompt='A cake with colorful characters that reads "EVERYDAY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
image=rgb_image,
masked_image=masked_image,
num_inference_steps=20,
strength=1.0,
guidance_scale=9.0,
height=rgb_image.shape[0],
width=rgb_image.shape[1],
seed=seed,
sort_priority="y",
)
if rtn_code >= 0:
save_images(results, img_save_folder)
print(f"Done, result images are saved in: {img_save_folder}")

inpaint/model/anytext/ocr_recog/RNN.py
@@ -0,0 +1,210 @@
from torch import nn
import torch
from .RecSVTR import Block
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
class Im2Im(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
return x
class Im2Seq(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
B, C, H, W = x.shape
# assert H == 1
x = x.reshape(B, C, H * W)
x = x.permute((0, 2, 1))
return x
class EncoderWithRNN(nn.Module):
def __init__(self, in_channels,**kwargs):
super(EncoderWithRNN, self).__init__()
hidden_size = kwargs.get('hidden_size', 256)
self.out_channels = hidden_size * 2
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True)
def forward(self, x):
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
return x
class SequenceEncoder(nn.Module):
def __init__(self, in_channels, encoder_type='rnn', **kwargs):
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
self.encoder_type = encoder_type
if encoder_type == 'reshape':
self.only_reshape = True
else:
support_encoder_dict = {
'reshape': Im2Seq,
'rnn': EncoderWithRNN,
'svtr': EncoderWithSVTR
}
assert encoder_type in support_encoder_dict, '{} must be one of {}'.format(
encoder_type, support_encoder_dict.keys())
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels,**kwargs)
self.out_channels = self.encoder.out_channels
self.only_reshape = False
def forward(self, x):
if self.encoder_type != 'svtr':
x = self.encoder_reshape(x)
if not self.only_reshape:
x = self.encoder(x)
return x
else:
x = self.encoder(x)
x = self.encoder_reshape(x)
return x
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
bias_attr=False,
groups=1,
act=nn.GELU):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
bias=bias_attr)
self.norm = nn.BatchNorm2d(out_channels)
self.act = Swish()  # note: the act argument is ignored here; Swish is always used
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class EncoderWithSVTR(nn.Module):
def __init__(
self,
in_channels,
dims=64, # XS
depth=2,
hidden_dims=120,
use_guide=False,
num_heads=8,
qkv_bias=True,
mlp_ratio=2.0,
drop_rate=0.1,
attn_drop_rate=0.1,
drop_path=0.,
qk_scale=None):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
self.conv1 = ConvBNLayer(
in_channels, in_channels // 8, padding=1, act='swish')
self.conv2 = ConvBNLayer(
in_channels // 8, hidden_dims, kernel_size=1, act='swish')
self.svtr_block = nn.ModuleList([
Block(
dim=hidden_dims,
num_heads=num_heads,
mixer='Global',
HW=None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer='swish',
attn_drop=attn_drop_rate,
drop_path=drop_path,
norm_layer='nn.LayerNorm',
epsilon=1e-05,
prenorm=False) for i in range(depth)
])
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
self.conv3 = ConvBNLayer(
hidden_dims, in_channels, kernel_size=1, act='swish')
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self.conv4 = ConvBNLayer(
2 * in_channels, in_channels // 8, padding=1, act='swish')
self.conv1x1 = ConvBNLayer(
in_channels // 8, dims, kernel_size=1, act='swish')
self.out_channels = dims
self.apply(self._init_weights)
def _init_weights(self, m):
# weight initialization
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# for use guide
if self.use_guide:
# detach to stop gradients (port of Paddle's z.stop_gradient = True)
z = x.clone().detach()
else:
z = x
# for short cut
h = z
# reduce dim
z = self.conv1(z)
z = self.conv2(z)
# SVTR global block
B, C, H, W = z.shape
z = z.flatten(2).permute(0, 2, 1)
for blk in self.svtr_block:
z = blk(z)
z = self.norm(z)
# last stage
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
z = self.conv3(z)
z = torch.cat((h, z), dim=1)
z = self.conv1x1(self.conv4(z))
return z
if __name__ == "__main__":
svtrRNN = EncoderWithSVTR(56)
print(svtrRNN)

inpaint/model/anytext/ocr_recog/RecCTCHead.py
@@ -0,0 +1,48 @@
from torch import nn
class CTCHead(nn.Module):
def __init__(self,
in_channels,
out_channels=6625,
fc_decay=0.0004,
mid_channels=None,
return_feats=False,
**kwargs):
super(CTCHead, self).__init__()
if mid_channels is None:
self.fc = nn.Linear(
in_channels,
out_channels,
bias=True,)
else:
self.fc1 = nn.Linear(
in_channels,
mid_channels,
bias=True,
)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
bias=True,
)
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
def forward(self, x, labels=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
x = self.fc1(x)
predicts = self.fc2(x)
if self.return_feats:
result = dict()
result['ctc'] = predicts
result['ctc_neck'] = x
else:
result = predicts
return result

inpaint/model/anytext/ocr_recog/RecModel.py
@@ -0,0 +1,45 @@
from torch import nn
from .RNN import SequenceEncoder, Im2Seq, Im2Im
from .RecMv1_enhance import MobileNetV1Enhance
from .RecCTCHead import CTCHead
backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance}
neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
head_dict = {'CTCHead':CTCHead}
class RecModel(nn.Module):
def __init__(self, config):
super().__init__()
assert 'in_channels' in config, 'in_channels must be in model config'
backbone_type = config.backbone.pop('type')
assert backbone_type in backbone_dict, f'backbone.type must be one of {backbone_dict}'
self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
neck_type = config.neck.pop('type')
assert neck_type in neck_dict, f'neck.type must be one of {neck_dict}'
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
head_type = config.head.pop('type')
assert head_type in head_dict, f'head.type must be one of {head_dict}'
self.head = head_dict[head_type](self.neck.out_channels, **config.head)
self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'
def load_3rd_state_dict(self, _3rd_name, _state):
self.backbone.load_3rd_state_dict(_3rd_name, _state)
self.neck.load_3rd_state_dict(_3rd_name, _state)
self.head.load_3rd_state_dict(_3rd_name, _state)
def forward(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head(x)
return x
def encode(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head.ctc_encoder(x)
return x
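if __name__ == "__main__":
    # Hypothetical config sketch (easydict is an assumption -- any attribute-style
    # mapping whose sub-configs support .pop() will do).
    from easydict import EasyDict
    config = EasyDict(dict(
        in_channels=3,
        backbone=dict(type="MobileNetV1Enhance", scale=0.5),
        neck=dict(type="SequenceEncoder", encoder_type="rnn", hidden_size=64),
        head=dict(type="CTCHead", out_channels=6625),
    ))
    model = RecModel(config)
    print(model.name)  # RecModel_MobileNetV1Enhance_SequenceEncoder_CTCHead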

inpaint/model/anytext/ocr_recog/RecMv1_enhance.py
@@ -0,0 +1,232 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import Activation
class ConvBNLayer(nn.Module):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='hard_swish'):
super(ConvBNLayer, self).__init__()
self.act = act
self._conv = nn.Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
bias=False)
self._batch_norm = nn.BatchNorm2d(
num_filters,
)
if self.act is not None:
self._act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class DepthwiseSeparable(nn.Module):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
dw_size=3,
padding=1,
use_se=False):
super(DepthwiseSeparable, self).__init__()
self.use_se = use_se
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=dw_size,
stride=stride,
padding=padding,
num_groups=int(num_groups * scale))
if use_se:
self._se = SEModule(int(num_filters1 * scale))
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
if self.use_se:
y = self._se(y)
y = self._pointwise_conv(y)
return y
class MobileNetV1Enhance(nn.Module):
def __init__(self,
in_channels=3,
scale=0.5,
last_conv_stride=1,
last_pool_type='max',
**kwargs):
super().__init__()
self.scale = scale
self.block_list = []
self.conv1 = ConvBNLayer(
num_channels=in_channels,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
conv2_1 = DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale)
self.block_list.append(conv2_1)
conv2_2 = DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=1,
scale=scale)
self.block_list.append(conv2_2)
conv3_1 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale)
self.block_list.append(conv3_1)
conv3_2 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=(2, 1),
scale=scale)
self.block_list.append(conv3_2)
conv4_1 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale)
self.block_list.append(conv4_1)
conv4_2 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=(2, 1),
scale=scale)
self.block_list.append(conv4_2)
for _ in range(5):
conv5 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
dw_size=5,
padding=2,
scale=scale,
use_se=False)
self.block_list.append(conv5)
conv5_6 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=(2, 1),
dw_size=5,
padding=2,
scale=scale,
use_se=True)
self.block_list.append(conv5_6)
conv6 = DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=last_conv_stride,
dw_size=5,
padding=2,
use_se=True,
scale=scale)
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == 'avg':
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
def forward(self, inputs):
y = self.conv1(inputs)
y = self.block_list(y)
y = self.pool(y)
return y
def hardsigmoid(x):
return F.relu6(x + 3., inplace=True) / 6.
class SEModule(nn.Module):
def __init__(self, channel, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
bias=True)
self.conv2 = nn.Conv2d(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
bias=True)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hardsigmoid(outputs)
x = torch.mul(inputs, outputs)
return x

inpaint/model/anytext/ocr_recog/RecSVTR.py
@@ -0,0 +1,591 @@
import torch
import torch.nn as nn
import numpy as np
from torch.nn.init import trunc_normal_, zeros_, ones_
from torch.nn import functional
def drop_path(x, drop_prob=0., training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if drop_prob == 0. or not training:
return x
keep_prob = torch.tensor(1 - drop_prob)
shape = (x.size()[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
random_tensor = torch.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
return output
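# Dividing by keep_prob rescales the surviving samples so the expectation is
# preserved: E[output] = keep_prob * (x / keep_prob) + (1 - keep_prob) * 0 = x.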
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
def forward(self, x):
return x * torch.sigmoid(x)
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
bias_attr=False,
groups=1,
act=nn.GELU):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
bias=bias_attr)
self.norm = nn.BatchNorm2d(out_channels)
self.act = act()
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
if isinstance(act_layer, str):
self.act = Swish()
else:
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMixer(nn.Module):
def __init__(
self,
dim,
num_heads=8,
HW=(8, 25),
local_k=(3, 3), ):
super().__init__()
self.HW = HW
self.dim = dim
self.local_mixer = nn.Conv2d(
dim,
dim,
local_k,
1, (local_k[0] // 2, local_k[1] // 2),
groups=num_heads,
# weight_attr=ParamAttr(initializer=KaimingNormal())
)
def forward(self, x):
h = self.HW[0]
w = self.HW[1]
# ported from Paddle: transpose -> permute, and 0 in reshape meant "keep dim"
x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w)
x = self.local_mixer(x)
x = x.flatten(2).permute(0, 2, 1)
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
mixer='Global',
HW=(8, 25),
local_k=(7, 11),
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.HW = HW
if HW is not None:
H = HW[0]
W = HW[1]
self.N = H * W
self.C = dim
if mixer == 'Local' and HW is not None:
hk = local_k[0]
wk = local_k[1]
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
for h in range(0, H):
for w in range(0, W):
mask[h * W + w, h:h + hk, w:w + wk] = 0.
mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
2].flatten(1)
mask_inf = torch.full([H * W, H * W],fill_value=float('-inf'))
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
self.mask = mask[None,None,:]
# self.mask = mask.unsqueeze([0, 1])
self.mixer = mixer
def forward(self, x):
if self.HW is not None:
N = self.N
C = self.C
else:
_, N, C = x.shape
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4))
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = (q.matmul(k.permute((0, 1, 3, 2))))
if self.mixer == 'Local':
attn += self.mask
attn = functional.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mixer='Global',
local_mixer=(7, 11),
HW=(8, 25),
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer='nn.LayerNorm',
epsilon=1e-6,
prenorm=True):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm1 = norm_layer(dim)
if mixer == 'Global' or mixer == 'Local':
self.mixer = Attention(
dim,
num_heads=num_heads,
mixer=mixer,
HW=HW,
local_k=local_mixer,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
elif mixer == 'Conv':
self.mixer = ConvMixer(
dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
else:
raise TypeError("The mixer must be one of [Global, Local, Conv]")
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
self.prenorm = prenorm
def forward(self, x):
if self.prenorm:
x = self.norm1(x + self.drop_path(self.mixer(x)))
x = self.norm2(x + self.drop_path(self.mlp(x)))
else:
x = x + self.drop_path(self.mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self,
img_size=(32, 100),
in_channels=3,
embed_dim=768,
sub_num=2):
super().__init__()
num_patches = (img_size[1] // (2 ** sub_num)) * \
(img_size[0] // (2 ** sub_num))
self.img_size = img_size
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
if sub_num == 2:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False))
if sub_num == 3:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 4,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False),
ConvBNLayer(
in_channels=embed_dim // 4,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False))
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).permute(0, 2, 1)
return x
class SubSample(nn.Module):
def __init__(self,
in_channels,
out_channels,
types='Pool',
stride=(2, 1),
sub_norm='nn.LayerNorm',
act=None):
super().__init__()
self.types = types
if types == 'Pool':
self.avgpool = nn.AvgPool2d(
kernel_size=(3, 5), stride=stride, padding=(1, 2))
self.maxpool = nn.MaxPool2d(
kernel_size=(3, 5), stride=stride, padding=(1, 2))
self.proj = nn.Linear(in_channels, out_channels)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
# weight_attr=ParamAttr(initializer=KaimingNormal())
)
self.norm = eval(sub_norm)(out_channels)
if act is not None:
self.act = act()
else:
self.act = None
def forward(self, x):
if self.types == 'Pool':
x1 = self.avgpool(x)
x2 = self.maxpool(x)
x = (x1 + x2) * 0.5
out = self.proj(x.flatten(2).permute((0, 2, 1)))
else:
x = self.conv(x)
out = x.flatten(2).permute((0, 2, 1))
out = self.norm(out)
if self.act is not None:
out = self.act(out)
return out
class SVTRNet(nn.Module):
def __init__(
self,
img_size=[48, 100],
in_channels=3,
embed_dim=[64, 128, 256],
depth=[3, 6, 3],
num_heads=[2, 4, 8],
mixer=['Local'] * 6 + ['Global'] * 6,  # Local atten, Global atten, Conv
local_mixer=[[7, 11], [7, 11], [7, 11]],
patch_merging='Conv', # Conv, Pool, None
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
last_drop=0.1,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer='nn.LayerNorm',
sub_norm='nn.LayerNorm',
epsilon=1e-6,
out_channels=192,
out_char_num=25,
block_unit='Block',
act='nn.GELU',
last_stage=True,
sub_num=2,
prenorm=True,
use_lenhead=False,
**kwargs):
super().__init__()
self.img_size = img_size
self.embed_dim = embed_dim
self.out_channels = out_channels
self.prenorm = prenorm
patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
self.patch_embed = PatchEmbed(
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim[0],
sub_num=sub_num)
num_patches = self.patch_embed.num_patches
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
# self.pos_embed = self.create_parameter(
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
# self.add_parameter("pos_embed", self.pos_embed)
self.pos_drop = nn.Dropout(p=drop_rate)
Block_unit = eval(block_unit)
dpr = np.linspace(0, drop_path_rate, sum(depth))
self.blocks1 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[0],
num_heads=num_heads[0],
mixer=mixer[0:depth[0]][i],
HW=self.HW,
local_mixer=local_mixer[0],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[0:depth[0]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm) for i in range(depth[0])
]
)
if patch_merging is not None:
self.sub_sample1 = SubSample(
embed_dim[0],
embed_dim[1],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging)
HW = [self.HW[0] // 2, self.HW[1]]
else:
HW = self.HW
self.patch_merging = patch_merging
self.blocks2 = nn.ModuleList([
Block_unit(
dim=embed_dim[1],
num_heads=num_heads[1],
mixer=mixer[depth[0]:depth[0] + depth[1]][i],
HW=HW,
local_mixer=local_mixer[1],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm) for i in range(depth[1])
])
if patch_merging is not None:
self.sub_sample2 = SubSample(
embed_dim[1],
embed_dim[2],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging)
HW = [self.HW[0] // 4, self.HW[1]]
else:
HW = self.HW
self.blocks3 = nn.ModuleList([
Block_unit(
dim=embed_dim[2],
num_heads=num_heads[2],
mixer=mixer[depth[0] + depth[1]:][i],
HW=HW,
local_mixer=local_mixer[2],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] + depth[1]:][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm) for i in range(depth[2])
])
self.last_stage = last_stage
if last_stage:
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
self.last_conv = nn.Conv2d(
in_channels=embed_dim[2],
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.hardswish = nn.Hardswish()
self.dropout = nn.Dropout(p=last_drop)
if not prenorm:
self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
self.use_lenhead = use_lenhead
if use_lenhead:
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
self.hardswish_len = nn.Hardswish()
self.dropout_len = nn.Dropout(
p=last_drop)
trunc_normal_(self.pos_embed,std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight,std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
def forward_features(self, x):
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample1(
x.permute([0, 2, 1]).reshape(
[-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
for blk in self.blocks2:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample2(
x.permute([0, 2, 1]).reshape(
[-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
for blk in self.blocks3:
x = blk(x)
if not self.prenorm:
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.use_lenhead:
len_x = self.len_conv(x.mean(1))
len_x = self.dropout_len(self.hardswish_len(len_x))
if self.last_stage:
if self.patch_merging is not None:
h = self.HW[0] // 4
else:
h = self.HW[0]
x = self.avg_pool(
x.permute([0, 2, 1]).reshape(
[-1, self.embed_dim[2], h, self.HW[1]]))
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
if self.use_lenhead:
return x, len_x
return x
if __name__ == "__main__":
a = torch.rand(1,3,48,100)
svtr = SVTRNet()
out = svtr(a)
print(svtr)
print(out.size())

inpaint/model/anytext/ocr_recog/common.py
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace
def forward(self, x):
return x * F.relu6(x + 3., inplace=self.inplace) / 6.
# out = max(0, min(1, slop*x+offset))
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
class GELU(nn.Module):
def __init__(self, inplace=True):
super(GELU, self).__init__()
self.inplace = inplace
def forward(self, x):
return torch.nn.functional.gelu(x)
class Swish(nn.Module):
def __init__(self, inplace=True):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
if self.inplace:
x.mul_(torch.sigmoid(x))
return x
else:
return x*torch.sigmoid(x)
class Activation(nn.Module):
def __init__(self, act_type, inplace=True):
super(Activation, self).__init__()
act_type = act_type.lower()
if act_type == 'relu':
self.act = nn.ReLU(inplace=inplace)
elif act_type == 'relu6':
self.act = nn.ReLU6(inplace=inplace)
elif act_type == 'sigmoid':
raise NotImplementedError
elif act_type == 'hard_sigmoid':
self.act = Hsigmoid(inplace)
elif act_type == 'hard_swish':
self.act = Hswish(inplace=inplace)
elif act_type == 'leakyrelu':
self.act = nn.LeakyReLU(inplace=inplace)
elif act_type == 'gelu':
self.act = GELU(inplace=inplace)
elif act_type == 'swish':
self.act = Swish(inplace=inplace)
else:
raise NotImplementedError
def forward(self, inputs):
return self.act(inputs)

inpaint/model/anytext/ocr_recog/en_dict.txt
@@ -0,0 +1,95 @@
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/

inpaint/model/anytext/ocr_recog/ppocr_keys_v1.txt
(diff suppressed because the file is too large)

inpaint/model/anytext/utils.py
@@ -0,0 +1,151 @@
import os
import datetime
import cv2
import numpy as np
from PIL import Image, ImageDraw
def save_images(img_list, folder):
if not os.path.exists(folder):
os.makedirs(folder)
now = datetime.datetime.now()
date_str = now.strftime("%Y-%m-%d")
folder_path = os.path.join(folder, date_str)
if not os.path.exists(folder_path):
os.makedirs(folder_path)
time_str = now.strftime("%H_%M_%S")
for idx, img in enumerate(img_list):
image_number = idx + 1
filename = f"{time_str}_{image_number}.jpg"
save_path = os.path.join(folder_path, filename)
cv2.imwrite(save_path, img[..., ::-1])
def check_channels(image):
channels = image.shape[2] if len(image.shape) == 3 else 1
if channels == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
elif channels > 3:
image = image[:, :, :3]
return image
def resize_image(img, max_length=768):
height, width = img.shape[:2]
max_dimension = max(height, width)
if max_dimension > max_length:
scale_factor = max_length / max_dimension
new_width = int(round(width * scale_factor))
new_height = int(round(height * scale_factor))
new_size = (new_width, new_height)
img = cv2.resize(img, new_size)
height, width = img.shape[:2]
img = cv2.resize(img, (width - (width % 64), height - (height % 64)))
return img
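# Worked example of the defaults: an 800x600 input scales by 768/800 to 768x576
# (both already multiples of 64, so the second resize is a no-op); a 1000x700
# input scales to 768x538, then snaps down to 768x512.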
def insert_spaces(string, nSpace):
if nSpace == 0:
return string
new_string = ""
for char in string:
new_string += char + " " * nSpace
return new_string[:-nSpace]
def draw_glyph(font, text):
g_size = 50
W, H = (512, 80)
new_font = font.font_variant(size=g_size)
img = Image.new(mode="1", size=(W, H), color=0)
draw = ImageDraw.Draw(img)
left, top, right, bottom = new_font.getbbox(text)
text_width = max(right - left, 5)
text_height = max(bottom - top, 5)
ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
new_font = font.font_variant(size=int(g_size * ratio))
text_width, text_height = new_font.getsize(text)  # Pillow < 10 API
offset_x, offset_y = new_font.getoffset(text)  # removed in Pillow 10; use getbbox there
x = (img.width - text_width) // 2
y = (img.height - text_height) // 2 - offset_y // 2
draw.text((x, y), text, font=new_font, fill="white")
img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
return img
def draw_glyph2(
font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True
):
enlarge_polygon = polygon * scale
rect = cv2.minAreaRect(enlarge_polygon)
box = cv2.boxPoints(rect)
box = box.astype(np.int64)  # np.int0 was removed in NumPy 2.0
w, h = rect[1]
angle = rect[2]
if angle < -45:
angle += 90
angle = -angle
if w < h:
angle += 90
vert = False
if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
_w = max(box[:, 0]) - min(box[:, 0])
_h = max(box[:, 1]) - min(box[:, 1])
if _h >= _w:
vert = True
angle = 0
img = np.zeros((height * scale, width * scale, 3), np.uint8)
img = Image.fromarray(img)
# infer font size
image4ratio = Image.new("RGB", img.size, "white")
draw = ImageDraw.Draw(image4ratio)
_, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
text_w = min(w, h) * (_tw / _th)
if text_w <= max(w, h):
# add space
if len(text) > 1 and not vert and add_space:
for i in range(1, 100):
text_space = insert_spaces(text, i)
_, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
if min(w, h) * (_tw2 / _th2) > max(w, h):
break
text = insert_spaces(text, i - 1)
font_size = min(w, h) * 0.80
else:
shrink = 0.75 if vert else 0.85
font_size = min(w, h) / (text_w / max(w, h)) * shrink
new_font = font.font_variant(size=int(font_size))
left, top, right, bottom = new_font.getbbox(text)
text_width = right - left
text_height = bottom - top
layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
draw = ImageDraw.Draw(layer)
if not vert:
draw.text(
(rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
text,
font=new_font,
fill=(255, 255, 255, 255),
)
else:
x_s = min(box[:, 0]) + _w // 2 - text_height // 2
y_s = min(box[:, 1])
for c in text:
draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
_, _t, _, _b = new_font.getbbox(c)
y_s += _b
rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
x_offset = int((img.width - rotated_layer.width) / 2)
y_offset = int((img.height - rotated_layer.height) / 2)
img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
return img

inpaint/model/base.py
@@ -0,0 +1,405 @@
import abc
from typing import Optional
import cv2
import torch
import numpy as np
from loguru import logger
from inpaint.helper import (
boxes_from_mask,
resize_max_size,
pad_img_to_modulo,
switch_mps_device,
)
from inpaint.schema import InpaintRequest, HDStrategy, SDSampler
from .helper.g_diffuser_bot import expand_image
from .utils import get_scheduler
class InpaintModel:
name = "base"
min_size: Optional[int] = None
pad_mod = 8
pad_to_square = False
is_erase_model = False
def __init__(self, device, **kwargs):
"""
Args:
device:
"""
device = switch_mps_device(self.name, device)
self.device = device
self.init_model(device, **kwargs)
@abc.abstractmethod
def init_model(self, device, **kwargs): ...
@staticmethod
@abc.abstractmethod
def is_downloaded() -> bool:
return False
@abc.abstractmethod
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W, 1] 255 masks 区域
return: BGR IMAGE
"""
...
@staticmethod
def download(): ...
def _pad_forward(self, image, mask, config: InpaintRequest):
origin_height, origin_width = image.shape[:2]
pad_image = pad_img_to_modulo(
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
)
pad_mask = pad_img_to_modulo(
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
)
# logger.info(f"final forward pad size: {pad_image.shape}")
image, mask = self.forward_pre_process(image, mask, config)
result = self.forward(pad_image, pad_mask, config)
result = result[0:origin_height, 0:origin_width, :]
result, image, mask = self.forward_post_process(result, image, mask, config)
if config.sd_keep_unmasked_area:
mask = mask[:, :, np.newaxis]
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
return result
def forward_pre_process(self, image, mask, config):
return image, mask
def forward_post_process(self, result, image, mask, config):
return result, image, mask
@torch.no_grad()
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
inpaint_result = None
# logger.info(f"hd_strategy: {config.hd_strategy}")
if config.hd_strategy == HDStrategy.CROP:
if max(image.shape) > config.hd_strategy_crop_trigger_size:
logger.info("Run crop strategy")
boxes = boxes_from_mask(mask)
crop_result = []
for box in boxes:
crop_image, crop_box = self._run_box(image, mask, box, config)
crop_result.append((crop_image, crop_box))
inpaint_result = image[:, :, ::-1]
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
inpaint_result[y1:y2, x1:x2, :] = crop_image
elif config.hd_strategy == HDStrategy.RESIZE:
if max(image.shape) > config.hd_strategy_resize_limit:
origin_size = image.shape[:2]
downsize_image = resize_max_size(
image, size_limit=config.hd_strategy_resize_limit
)
downsize_mask = resize_max_size(
mask, size_limit=config.hd_strategy_resize_limit
)
logger.info(
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
)
inpaint_result = self._pad_forward(
downsize_image, downsize_mask, config
)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
original_pixel_indices
]
if inpaint_result is None:
inpaint_result = self._pad_forward(image, mask, config)
return inpaint_result
def _crop_box(self, image, mask, box, config: InpaintRequest):
"""
Args:
image: [H, W, C] RGB
mask: [H, W, 1]
box: [left,top,right,bottom]
Returns:
crop_img: [H, W, C] RGB, crop_mask: [H, W], [l, t, r, b]
"""
box_h = box[3] - box[1]
box_w = box[2] - box[0]
cx = (box[0] + box[2]) // 2
cy = (box[1] + box[3]) // 2
img_h, img_w = image.shape[:2]
w = box_w + config.hd_strategy_crop_margin * 2
h = box_h + config.hd_strategy_crop_margin * 2
_l = cx - w // 2
_r = cx + w // 2
_t = cy - h // 2
_b = cy + h // 2
l = max(_l, 0)
r = min(_r, img_w)
t = max(_t, 0)
b = min(_b, img_h)
# try to get more context when crop around image edge
if _l < 0:
r += abs(_l)
if _r > img_w:
l -= _r - img_w
if _t < 0:
b += abs(_t)
if _b > img_h:
t -= _b - img_h
l = max(l, 0)
r = min(r, img_w)
t = max(t, 0)
b = min(b, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
# logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
return crop_img, crop_mask, [l, t, r, b]
def _calculate_cdf(self, histogram):
cdf = histogram.cumsum()
normalized_cdf = cdf / float(cdf.max())
return normalized_cdf
def _calculate_lookup(self, source_cdf, reference_cdf):
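# classic histogram matching: map each source intensity to the first reference bin whose CDF reaches the source CDF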
lookup_table = np.zeros(256)
lookup_val = 0
for source_index, source_val in enumerate(source_cdf):
for reference_index, reference_val in enumerate(reference_cdf):
if reference_val >= source_val:
lookup_val = reference_index
break
lookup_table[source_index] = lookup_val
return lookup_table
def _match_histograms(self, source, reference, mask):
transformed_channels = []
if len(mask.shape) == 3:
mask = mask[:, :, -1]
for channel in range(source.shape[-1]):
source_channel = source[:, :, channel]
reference_channel = reference[:, :, channel]
# only calculate histograms for non-masked parts
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
reference_histogram, _ = np.histogram(
reference_channel[mask == 0], 256, [0, 256]
)
source_cdf = self._calculate_cdf(source_histogram)
reference_cdf = self._calculate_cdf(reference_histogram)
lookup = self._calculate_lookup(source_cdf, reference_cdf)
transformed_channels.append(cv2.LUT(source_channel, lookup))
result = cv2.merge(transformed_channels)
result = cv2.convertScaleAbs(result)
return result
def _apply_cropper(self, image, mask, config: InpaintRequest):
img_h, img_w = image.shape[:2]
l, t, w, h = (
config.croper_x,
config.croper_y,
config.croper_width,
config.croper_height,
)
r = l + w
b = t + h
l = max(l, 0)
r = min(r, img_w)
t = max(t, 0)
b = min(b, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
return crop_img, crop_mask, (l, t, r, b)
def _run_box(self, image, mask, box, config: InpaintRequest):
"""
Args:
image: [H, W, C] RGB
mask: [H, W, 1]
box: [left,top,right,bottom]
Returns:
BGR IMAGE, [l, t, r, b]
"""
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
class DiffusionInpaintModel(InpaintModel):
def __init__(self, device, **kwargs):
self.model_info = kwargs["model_info"]
self.model_id_or_path = self.model_info.path
super().__init__(device, **kwargs)
@torch.no_grad()
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
# boxes = boxes_from_mask(mask)
if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
elif config.use_extender:
inpaint_result = self._do_outpainting(image, config)
else:
inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result
def _do_outpainting(self, image, config: InpaintRequest):
# the cropper and the image share the same coordinate system; croper_x/y may be negative
# crop the outpainting region out of the image
image_h, image_w = image.shape[:2]
cropper_l = config.extender_x
cropper_t = config.extender_y
cropper_r = config.extender_x + config.extender_width
cropper_b = config.extender_y + config.extender_height
image_l = 0
image_t = 0
image_r = image_w
image_b = image_h
# intersect the two boxes, similar to an IoU computation
l = max(cropper_l, image_l)
t = max(cropper_t, image_t)
r = min(cropper_r, image_r)
b = min(cropper_b, image_b)
assert (
0 <= l < r and 0 <= t < b
), f"cropper and image not overlap, {l},{t},{r},{b}"
cropped_image = image[t:b, l:r, :]
padding_l = max(0, image_l - cropper_l)
padding_t = max(0, image_t - cropper_t)
padding_r = max(0, cropper_r - image_r)
padding_b = max(0, cropper_b - image_b)
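# padding_* measure how far the extender reaches beyond the original image on each side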
expanded_image, mask_image = expand_image(
cropped_image,
left=padding_l,
top=padding_t,
right=padding_r,
bottom=padding_b,
)
# the final expanded image, BGR
expanded_cropped_result_image = self._scaled_pad_forward(
expanded_image, mask_image, config
)
# RGB -> BGR
outpainting_image = cv2.copyMakeBorder(
image,
left=padding_l,
top=padding_t,
right=padding_r,
bottom=padding_b,
borderType=cv2.BORDER_CONSTANT,
value=0,
)[:, :, ::-1]
# paste cropped_result_image onto outpainting_image; no blending is needed for this step
paste_t = 0 if config.extender_y < 0 else config.extender_y
paste_l = 0 if config.extender_x < 0 else config.extender_x
outpainting_image[
paste_t : paste_t + expanded_cropped_result_image.shape[0],
paste_l : paste_l + expanded_cropped_result_image.shape[1],
:,
] = expanded_cropped_result_image
return outpainting_image
def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
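# sd_scale < 1 trades detail for speed/VRAM: run diffusion at a reduced size, then upscale the result back to the original size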
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
if config.sd_scale != 1:
logger.info(
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
)
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
return inpaint_result
def set_scheduler(self, config: InpaintRequest):
scheduler_config = self.model.scheduler.config
sd_sampler = config.sd_sampler
if config.sd_lcm_lora and self.model_info.support_lcm_lora:
sd_sampler = SDSampler.lcm
logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
scheduler = get_scheduler(sd_sampler, scheduler_config)
self.model.scheduler = scheduler
def forward_pre_process(self, image, mask, config):
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
return image, mask
def forward_post_process(self, result, image, mask, config):
if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
if config.use_extender and config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
return result, image, mask
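
To make the contract above concrete, here is a minimal, hypothetical subclass sketch (the class name and the Gaussian-blur "inpainting" are illustrative only, not part of this repository):

import cv2
from inpaint.schema import InpaintRequest

class DummyBlurModel(InpaintModel):
    name = "dummy_blur"
    pad_mod = 8
    is_erase_model = True

    def init_model(self, device, **kwargs):
        pass  # this toy model has no weights to load

    @staticmethod
    def is_downloaded() -> bool:
        return True

    def forward(self, image, mask, config: InpaintRequest):
        # naive "inpainting": blur everything; __call__/_pad_forward handle
        # padding, HD strategies, and pasting unmasked pixels back
        blurred = cv2.GaussianBlur(image, (31, 31), 0)
        return blurred[:, :, ::-1]  # the contract requires BGR output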

new file: inpaint/model/brushnet/brushnet.py (+931 lines)
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D, get_down_block, get_up_block,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from .unet_2d_blocks import MidBlock2D
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class BrushNetOutput(BaseOutput):
"""
The output of [`BrushNetModel`].
Args:
up_block_res_samples (`tuple[torch.Tensor]`):
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
used to condition the original UNet's upsampling activations.
down_block_res_samples (`tuple[torch.Tensor]`):
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations.
mid_block_res_sample (`torch.Tensor`):
The activation of the middle block (the lowest sample resolution). The tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation.
"""
up_block_res_samples: Tuple[torch.Tensor]
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
class BrushNetModel(ModelMixin, ConfigMixin):
"""
A BrushNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, the normalization and activation layers are skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
TODO(Patrick) - unused parameter.
addition_embed_type_num_heads (`int`, defaults to 64):
The number of heads to use for the `TextTimeEmbedding` layer.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 5,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in_condition = nn.Conv2d(
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
self.down_blocks = nn.ModuleList([])
self.brushnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
# mid
mid_block_channel = block_out_channels[-1]
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_mid_block = brushnet_block
self.mid_block = MidBlock2D(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
dropout=0.0,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
)
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
self.up_blocks = nn.ModuleList([])
self.brushnet_up_blocks = nn.ModuleList([])
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=reversed_num_attention_heads[i],
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
for _ in range(layers_per_block + 1):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
conditioning_channels: int = 5,
):
r"""
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
brushnet = cls(
in_channels=unet.config.in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D'],
mid_block_type='MidBlock2D',
up_block_types=['UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'],
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
if load_weights_from_unet:
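# conv_in_condition expects in_channels + conditioning_channels = 9 input channels (4 noisy latent + 4 masked-image latent + 1 mask); reuse the UNet conv_in weights for both 4-channel groups and leave the mask channel zero-initialized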
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
brushnet.conv_in_condition.bias = unet.conv_in.bias
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if brushnet.class_embedding:
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
return brushnet
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
brushnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
"""
The [`BrushNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
brushnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for BrushNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`):
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
Returns:
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.brushnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
else:
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if self.config.addition_embed_type is not None:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process
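# BrushNet consumes the noisy latent concatenated with its 5-channel condition (masked-image latent + mask), matching conv_in_condition above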
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
sample = self.conv_in_condition(brushnet_cond)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. PaintingNet down blocks
brushnet_down_block_res_samples = ()
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
down_block_res_sample = brushnet_down_block(down_block_res_sample)
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
# 5. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = self.mid_block(sample, emb)
# 6. BrushNet mid blocks
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
# 7. up
up_block_res_samples = ()
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample, up_res_samples = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
return_res_samples=True
)
else:
sample, up_res_samples = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
return_res_samples=True
)
up_block_res_samples += up_res_samples
# 8. BrushNet up blocks
brushnet_up_block_res_samples = ()
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
up_block_res_sample = brushnet_up_block(up_block_res_sample)
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
# 6. scaling
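# guess mode: scale residuals geometrically from 0.1 to 1.0 across the down -> mid -> up residual sequence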
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device)  # 0.1 to 1.0
scales = scales * conditioning_scale
brushnet_down_block_res_samples = [s * scale for s, scale in zip(brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)])]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
brushnet_up_block_res_samples = [s * scale for s, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :])]
else:
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
brushnet_down_block_res_samples]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
if self.config.global_pool_conditions:
brushnet_down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
]
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
brushnet_up_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
]
if not return_dict:
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
return BrushNetOutput(
down_block_res_samples=brushnet_down_block_res_samples,
mid_block_res_sample=brushnet_mid_block_res_sample,
up_block_res_samples=brushnet_up_block_res_samples
)
def zero_module(module):
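# zero-initialized 1x1 convolutions make the BrushNet branch a no-op at the start of training (the same trick ControlNet uses)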
for p in module.parameters():
nn.init.zeros_(p)
return module
if __name__ == "__main__":
BrushNetModel.from_pretrained("/Users/cwq/data/models/brushnet/brushnet_random_mask", variant='fp16',
use_safetensors=True)
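
As a hedged usage sketch, from_unet can bootstrap a BrushNet from an existing Stable Diffusion UNet (the model id below is illustrative):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
brushnet = BrushNetModel.from_unet(unet, conditioning_channels=5)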

new file: inpaint/model/brushnet/brushnet_unet_forward.py (+322 lines)
from typing import Union, Optional, Dict, Any, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers
def brushnet_unet_forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
mid_block_add_sample: Optional[torch.Tensor] = None,
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
through the `self.time_embedding` layer to obtain the timestep embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
Additional residuals to be added to the UNet long skip connections from down blocks to up blocks,
for example from ControlNet side model(s).
mid_block_additional_residual (`torch.Tensor`, *optional*):
An additional residual to be added to the UNet mid block output, for example from a ControlNet side
model.
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
Additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s).
down_block_add_samples (`tuple` of `torch.Tensor`, *optional*):
BrushNet residuals to be added inside the UNet down blocks, one per resnet (and downsampler).
mid_block_add_sample (`torch.Tensor`, *optional*):
BrushNet residual to be added to the UNet mid block output.
up_block_add_samples (`tuple` of `torch.Tensor`, *optional*):
BrushNet residuals to be added inside the UNet up blocks, one per resnet (and upsampler).
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2 ** self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
if class_emb is not None:
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
aug_emb = self.get_aug_embed(
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
)
if self.config.addition_embed_type == "image_hint":
aug_emb, hint = aug_emb
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
encoder_hidden_states = self.process_encoder_hidden_states(
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
)
# 2. pre-process
sample = self.conv_in(sample)
# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
# maintain backward compatibility for legacy usage, where
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
deprecate(
"T2I should not use down_block_additional_residuals",
"1.3.0",
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
standard_warn=False,
)
down_intrablock_additional_residuals = down_block_additional_residuals
is_adapter = True
down_block_res_samples = (sample,)
if is_brushnet:
sample = sample + down_block_add_samples.pop(0)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
if is_adapter and len(down_intrablock_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
additional_residuals = {}
if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)
down_block_res_samples += res_samples
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = self.mid_block(sample, emb)
# To support T2I-Adapter-XL
if (
is_adapter
and len(down_intrablock_additional_residuals) > 0
and sample.shape == down_intrablock_additional_residuals[0].shape
):
sample += down_intrablock_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
if is_brushnet:
sample = sample + mid_block_add_sample
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
for _ in range(
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
for _ in range(
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
**additional_residuals,
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)
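This forward is not called directly; the wrapper later in this diff installs it on an existing UNet instance with `brushnet_unet_forward.__get__(unet, unet.__class__)`. A minimal sketch of that binding idiom, using hypothetical names:

class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    # a free function; `__get__` turns it into a bound method of one instance
    return self.name.upper()

g = Greeter("brushnet")
g.shout = shout.__get__(g, Greeter)
assert g.shout() == "BRUSHNET"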


@@ -0,0 +1,157 @@
import PIL.Image
import cv2
import torch
from loguru import logger
import numpy as np
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..original_sd_configs import get_config_files
from ..utils import (
handle_from_pretrained_exceptions,
get_torch_dtype,
enable_low_mem,
is_local_files_only,
)
from .brushnet import BrushNetModel
from .brushnet_unet_forward import brushnet_unet_forward
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
UpBlock2D_forward
from ...schema import InpaintRequest, ModelType
class BrushNetWrapper(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
self.model_info = kwargs["model_info"]
self.brushnet_method = kwargs["brushnet_method"]
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
self.torch_dtype = torch_dtype
model_kwargs = {
**kwargs.get("pipe_components", {}),
"local_files_only": is_local_files_only(**kwargs),
}
self.local_files_only = model_kwargs["local_files_only"]
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
"cpu_offload", False
)
if disable_nsfw_checker:
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
self.model = StableDiffusionBrushNetPipeline.from_single_file(
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker,
original_config_file=get_config_files()['v1'],
brushnet=brushnet,
**model_kwargs,
)
else:
self.model = handle_from_pretrained_exceptions(
StableDiffusionBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
variant="fp16",
torch_dtype=torch_dtype,
brushnet=brushnet,
**model_kwargs,
)
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
for down_block in self.model.brushnet.down_blocks:
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
for up_block in self.model.brushnet.up_blocks:
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
for down_block in self.model.unet.down_blocks:
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
else:
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
for up_block in self.model.unet.up_blocks:
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
else:
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
def switch_brushnet_method(self, new_method: str):
self.brushnet_method = new_method
brushnet = BrushNetModel.from_pretrained(
new_method,
resume_download=True,
local_files_only=self.local_files_only,
torch_dtype=self.torch_dtype,
).to(self.model.device)
self.model.brushnet = brushnet
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
img_h, img_w = image.shape[:2]
normalized_mask = mask.astype("float32") / 255.0  # [H, W, 1], broadcasts over RGB
image = image * (1 - normalized_mask)
image = image.astype(np.uint8)
output = self.model(
image=PIL.Image.fromarray(image),
prompt=config.prompt,
negative_prompt=config.negative_prompt,
mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
num_inference_steps=config.sd_steps,
# strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
brushnet_conditioning_scale=config.brushnet_conditioning_scale,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
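A minimal usage sketch for the wrapper above, assuming an already-initialized BrushNetWrapper and a populated InpaintRequest; the file paths and mask region are placeholders:

import cv2
import numpy as np

def run_brushnet(model, request, image_path: str, out_path: str):
    # hypothetical driver: `model` is an initialized BrushNetWrapper, `request`
    # an InpaintRequest with prompt / sd_steps / sd_seed already filled in
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)  # [H, W, 3] RGB
    mask = np.zeros((*image.shape[:2], 1), dtype=np.uint8)
    mask[100:200, 100:200] = 255  # 255 marks the region to repaint
    cv2.imwrite(out_path, model.forward(image, mask, request))  # forward returns BGR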

(File diff suppressed because it is too large.)


@@ -0,0 +1,388 @@
from typing import Dict, Any, Optional, Tuple
import torch
from diffusers.models.resnet import ResnetBlock2D
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import apply_freeu
from torch import nn
class MidBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
use_linear_projection: bool = False,
):
super().__init__()
self.has_cross_attention = False
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
for i in range(num_layers):
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for resnet in self.resnets[1:]:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
def DownBlock2D_forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0,
down_block_add_samples: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def CrossAttnDownBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
down_block_add_samples: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=lora_scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def CrossAttnUpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
return_res_samples: Optional[bool] = False,
up_block_add_samples: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
if return_res_samples:
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0)
if return_res_samples:
return hidden_states, output_states
else:
return hidden_states
def UpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
return_res_samples: Optional[bool] = False,
up_block_add_samples: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
if return_res_samples:
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
if return_res_samples:
return hidden_states, output_states
else:
return hidden_states
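The pop-based bookkeeping above implies a strict contract on list lengths: each patched block consumes one additional sample per resnet, plus one more if it ends in a down/upsampler. A sketch of that count, under the same assumptions:

def samples_needed(block) -> int:
    # number of `*_add_samples` entries one patched block pops
    samplers = getattr(block, "downsamplers", None) or getattr(block, "upsamplers", None)
    return len(block.resnets) + (samplers is not None)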

inpaint/model/controlnet.py (new file, 194 lines)

@@ -0,0 +1,194 @@
import PIL.Image
import cv2
import torch
from diffusers import ControlNetModel
from loguru import logger
from iopaint.schema import InpaintRequest, ModelType
from .base import DiffusionInpaintModel
from .helper.controlnet_preprocess import (
make_canny_control_image,
make_openpose_control_image,
make_depth_control_image,
make_inpaint_control_image,
)
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
from .original_sd_configs import get_config_files
from .utils import (
get_scheduler,
handle_from_pretrained_exceptions,
get_torch_dtype,
enable_low_mem,
is_local_files_only,
)
class ControlNet(DiffusionInpaintModel):
name = "controlnet"
pad_mod = 8
min_size = 512
@property
def lcm_lora_id(self):
if self.model_info.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SD_INPAINT,
]:
return "latent-consistency/lcm-lora-sdv1-5"
if self.model_info.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return "latent-consistency/lcm-lora-sdxl"
raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")
def init_model(self, device: torch.device, **kwargs):
model_info = kwargs["model_info"]
controlnet_method = kwargs["controlnet_method"]
self.model_info = model_info
self.controlnet_method = controlnet_method
model_kwargs = {
**kwargs.get("pipe_components", {}),
"local_files_only": is_local_files_only(**kwargs),
}
self.local_files_only = model_kwargs["local_files_only"]
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
"cpu_offload", False
)
if disable_nsfw_checker:
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
self.torch_dtype = torch_dtype
original_config_file_name = "v1"
if model_info.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SD_INPAINT,
]:
from diffusers import (
StableDiffusionControlNetInpaintPipeline as PipeClass,
)
original_config_file_name = "v1"
elif model_info.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
from diffusers import (
StableDiffusionXLControlNetInpaintPipeline as PipeClass,
)
original_config_file_name = "xl"
controlnet = ControlNetModel.from_pretrained(
pretrained_model_name_or_path=controlnet_method,
resume_download=True,
local_files_only=model_kwargs["local_files_only"],
torch_dtype=self.torch_dtype,
)
if model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
self.model = PipeClass.from_single_file(
model_info.path,
controlnet=controlnet,
load_safety_checker=not disable_nsfw_checker,
torch_dtype=torch_dtype,
original_config_file=get_config_files()[original_config_file_name],
**model_kwargs,
)
else:
self.model = handle_from_pretrained_exceptions(
PipeClass.from_pretrained,
pretrained_model_name_or_path=model_info.path,
controlnet=controlnet,
variant="fp16",
torch_dtype=torch_dtype,
**model_kwargs,
)
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
def switch_controlnet_method(self, new_method: str):
self.controlnet_method = new_method
controlnet = ControlNetModel.from_pretrained(
new_method,
resume_download=True,
local_files_only=self.local_files_only,
torch_dtype=self.torch_dtype,
).to(self.model.device)
self.model.controlnet = controlnet
def _get_control_image(self, image, mask):
if "canny" in self.controlnet_method:
control_image = make_canny_control_image(image)
elif "openpose" in self.controlnet_method:
control_image = make_openpose_control_image(image)
elif "depth" in self.controlnet_method:
control_image = make_depth_control_image(image)
elif "inpaint" in self.controlnet_method:
control_image = make_inpaint_control_image(image, mask)
else:
raise NotImplementedError(f"{self.controlnet_method} not implemented")
return control_image
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
scheduler_config = self.model.scheduler.config
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
self.model.scheduler = scheduler
img_h, img_w = image.shape[:2]
control_image = self._get_control_image(image, mask)
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
image = PIL.Image.fromarray(image)
output = self.model(
image=image,
mask_image=mask_image,
control_image=control_image,
prompt=config.prompt,
negative_prompt=config.negative_prompt,
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
controlnet_conditioning_scale=config.controlnet_conditioning_scale,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
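Because switch_controlnet_method swaps only the ControlNetModel, the base pipeline stays in memory; a hypothetical driver using it:

def repaint_with_canny(model, image, mask, request):
    # `model` is an initialized ControlNet wrapper; the repo id is just an example
    model.switch_controlnet_method("lllyasviel/control_v11p_sd15_canny")
    return model.forward(image, mask, request)  # BGR uint8 result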


@@ -0,0 +1,193 @@
import torch
import numpy as np
from tqdm import tqdm
from .utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
from loguru import logger
class DDIMSampler(object):
def __init__(self, model, schedule="linear"):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(self, steps, conditioning, batch_size, shape):
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# samples: 1,3,128,128
return self.ddim_sampling(
conditioning,
size,
quantize_denoised=False,
ddim_use_original_steps=False,
noise_dropout=0,
temperature=1.0,
)
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
ddim_use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
):
device = self.model.betas.device
b = shape[0]
img = torch.randn(shape, device=device, dtype=cond.dtype)
timesteps = (
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
)
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
)
img, _ = outs
return img
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
):
b, *_, device = *x.shape, x.device
e_t = self.model.apply_model(x, t, c)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:  # unused in practice
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:  # unused in practice
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
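p_sample_ddim implements the standard DDIM update; restated compactly, with a usage sketch (assuming `model` exposes the attributes used above):

# x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * noise
# pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t); with eta = 0 the noise term vanishes

def sample_latents(model, cond, steps: int = 50):
    # `model` must provide apply_model, alphas_cumprod, betas, num_timesteps, device
    sampler = DDIMSampler(model)
    return sampler.sample(
        steps=steps, conditioning=cond, batch_size=cond.shape[0], shape=(3, 128, 128)
    )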

inpaint/model/fcf.py (new file, 1737 lines)

(File diff suppressed because it is too large.)



@@ -0,0 +1,68 @@
import torch
import PIL
import cv2
from PIL import Image
import numpy as np
from iopaint.helper import pad_img_to_modulo
def make_canny_control_image(image: np.ndarray) -> Image.Image:
canny_image = cv2.Canny(image, 100, 200)
canny_image = canny_image[:, :, None]
canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
canny_image = PIL.Image.fromarray(canny_image)
control_image = canny_image
return control_image
def make_openpose_control_image(image: np.ndarray) -> Image.Image:
from controlnet_aux import OpenposeDetector
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
control_image = processor(image, hand_and_face=True)
return control_image
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(
input_image,
(W, H),
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
)
return img
def make_depth_control_image(image: np.ndarray) -> Image.Image:
from controlnet_aux import MidasDetector
midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
origin_height, origin_width = image.shape[:2]
pad_image = pad_img_to_modulo(image, mod=64, square=False, min_size=512)
depth_image = midas(pad_image)
depth_image = depth_image[0:origin_height, 0:origin_width]
depth_image = depth_image[:, :, None]
depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2)
control_image = PIL.Image.fromarray(depth_image)
return control_image
def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor:
"""
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
"""
image = image.astype(np.float32) / 255.0
image[mask[:, :, -1] > 128] = -1.0 # set as masked pixel
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image
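A small sketch combining the helpers above on one input; only canny and inpaint are shown since they need no extra model downloads:

def preview_control_inputs(image_rgb: np.ndarray, mask: np.ndarray):
    # image_rgb: [H, W, 3] uint8; mask: [H, W, 1] with 255 marking the repaint area
    canny = make_canny_control_image(image_rgb)            # PIL RGB edge map
    inpaint = make_inpaint_control_image(image_rgb, mask)  # [1, 3, H, W], masked pixels = -1
    return canny, inpaint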


@@ -0,0 +1,41 @@
import torch
from transformers import PreTrainedModel
from ..utils import torch_gc
class CPUTextEncoderWrapper(PreTrainedModel):
def __init__(self, text_encoder, torch_dtype):
super().__init__(text_encoder.config)
self.config = text_encoder.config
self._device = text_encoder.device
# cpu not support float16
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype
del text_encoder
torch_gc()
def __call__(self, x, **kwargs):
input_device = x.device
original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
for k, v in original_output.items():
if isinstance(v, tuple):
original_output[k] = [
v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
]
else:
original_output[k] = v.to(input_device).to(self.torch_dtype)
return original_output
@property
def dtype(self):
return self.torch_dtype
@property
def device(self) -> torch.device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
return self._device
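A one-line integration sketch, mirroring how the pipelines earlier in this diff install the wrapper:

def move_text_encoder_to_cpu(pipe, torch_dtype=torch.float16):
    # run prompt encoding on CPU (float32) while the rest of the pipeline stays put
    pipe.text_encoder = CPUTextEncoderWrapper(pipe.text_encoder, torch_dtype)
    return pipe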


@@ -0,0 +1,62 @@
import cv2
import numpy as np
def expand_image(cv2_img, top: int, right: int, bottom: int, left: int):
assert cv2_img.shape[2] == 3
origin_h, origin_w = cv2_img.shape[:2]
# TODO: which is better?
# new_img = np.ones((new_height, new_width, 3), np.uint8) * 255
new_img = cv2.copyMakeBorder(
cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
)
# NOTE: as written, these inner paddings are always 0; the conditionals look
# like a placeholder for an inner-padding feature that was never wired up.
inner_padding_left = 0 if left > 0 else 0
inner_padding_right = 0 if right > 0 else 0
inner_padding_top = 0 if top > 0 else 0
inner_padding_bottom = 0 if bottom > 0 else 0
mask_image = np.zeros(
(
origin_h - inner_padding_top - inner_padding_bottom,
origin_w - inner_padding_left - inner_padding_right,
),
np.uint8,
)
mask_image = cv2.copyMakeBorder(
mask_image,
top + inner_padding_top,
bottom + inner_padding_bottom,
left + inner_padding_left,
right + inner_padding_right,
cv2.BORDER_CONSTANT,
value=255,
)
# k = 2*int(min(origin_h, origin_w) // 6)+1
# k = 7
# mask_image = cv2.GaussianBlur(mask_image, (k, k), 0)
return new_img, mask_image
if __name__ == "__main__":
from pathlib import Path
current_dir = Path(__file__).parent.absolute().resolve()
image_path = "/Users/cwq/code/github/IOPaint/iopaint/tests/bunny.jpeg"
init_image = cv2.imread(str(image_path))
    # expand_image takes no softness/space arguments; passing them raises TypeError
    init_image, mask_image = expand_image(
        init_image,
        top=0,
        right=0,
        bottom=0,
        left=100,
    )
print(mask_image.dtype, mask_image.min(), mask_image.max())
print(init_image.dtype, init_image.min(), init_image.max())
mask_image = mask_image.astype(np.uint8)
init_image = init_image.astype(np.uint8)
cv2.imwrite("expanded_image.png", init_image)
cv2.imwrite("expanded_mask.png", mask_image)


@@ -0,0 +1,64 @@
import PIL.Image
import cv2
import torch
from loguru import logger
from iopaint.const import INSTRUCT_PIX2PIX_NAME
from .base import DiffusionInpaintModel
from iopaint.schema import InpaintRequest
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
class InstructPix2Pix(DiffusionInpaintModel):
name = INSTRUCT_PIX2PIX_NAME
pad_mod = 8
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from diffusers import StableDiffusionInstructPix2PixPipeline
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
)
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
example: edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
"""
output = self.model(
image=PIL.Image.fromarray(image),
prompt=config.prompt,
negative_prompt=config.negative_prompt,
num_inference_steps=config.sd_steps,
image_guidance_scale=config.p2p_image_guidance_scale,
guidance_scale=config.sd_guidance_scale,
output_type="np",
generator=torch.manual_seed(config.sd_seed),
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output


@@ -0,0 +1,65 @@
import PIL.Image
import cv2
import numpy as np
import torch
from iopaint.const import KANDINSKY22_NAME
from .base import DiffusionInpaintModel
from iopaint.schema import InpaintRequest
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
class Kandinsky(DiffusionInpaintModel):
pad_mod = 64
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from diffusers import AutoPipelineForInpainting
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {
"torch_dtype": torch_dtype,
"local_files_only": is_local_files_only(**kwargs),
}
self.model = AutoPipelineForInpainting.from_pretrained(
self.name, **model_kwargs
).to(device)
enable_low_mem(self.model, kwargs.get("low_mem", False))
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
generator = torch.manual_seed(config.sd_seed)
mask = mask.astype(np.float32) / 255
img_h, img_w = image.shape[:2]
# Kandinsky pipelines do not support the strength parameter
output = self.model(
prompt=config.prompt,
negative_prompt=config.negative_prompt,
image=PIL.Image.fromarray(image),
mask_image=mask[:, :, 0],
height=img_h,
width=img_w,
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback_on_step_end=self.callback,
generator=generator,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
class Kandinsky22(Kandinsky):
name = KANDINSKY22_NAME

inpaint/model/lama.py (new file, 57 lines)

@@ -0,0 +1,57 @@
import os
import cv2
import numpy as np
import torch
from iopaint.helper import (
norm_img,
get_cache_path_by_url,
load_jit_model,
download_model,
)
from iopaint.schema import InpaintRequest
from .base import InpaintModel
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
class LaMa(InpaintModel):
name = "lama"
pad_mod = 8
is_erase_model = True
@staticmethod
def download():
download_model(LAMA_MODEL_URL, LAMA_MODEL_MD5)
def init_model(self, device, **kwargs):
self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()
@staticmethod
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
return: BGR IMAGE
"""
image = norm_img(image)
mask = norm_img(mask)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
inpainted_image = self.model(image, mask)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
return cur_res
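forward above relies on norm_img from iopaint.helper; as a sketch of the assumed contract, it converts HWC uint8 to CHW float32 in [0, 1] (the exact helper may differ):

import numpy as np

def norm_img_sketch(np_img: np.ndarray) -> np.ndarray:
    if np_img.ndim == 2:  # grayscale mask -> add a channel axis
        np_img = np_img[:, :, np.newaxis]
    return np.transpose(np_img, (2, 0, 1)).astype("float32") / 255.0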

inpaint/model/ldm.py (new file, 336 lines)

@@ -0,0 +1,336 @@
import os
import numpy as np
import torch
from loguru import logger
from .base import InpaintModel
from .ddim_sampler import DDIMSampler
from .plms_sampler import PLMSSampler
from iopaint.schema import InpaintRequest, LDMSampler
torch.manual_seed(42)
import torch.nn as nn
from iopaint.helper import (
download_model,
norm_img,
get_cache_path_by_url,
load_jit_model,
)
from .utils import (
make_beta_schedule,
timestep_embedding,
)
LDM_ENCODE_MODEL_URL = os.environ.get(
"LDM_ENCODE_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
)
LDM_ENCODE_MODEL_MD5 = os.environ.get(
"LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
)
LDM_DECODE_MODEL_URL = os.environ.get(
"LDM_DECODE_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
)
LDM_DECODE_MODEL_MD5 = os.environ.get(
"LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
)
LDM_DIFFUSION_MODEL_URL = os.environ.get(
"LDM_DIFFUSION_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
)
LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
"LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
)
class DDPM(nn.Module):
# classic DDPM with Gaussian diffusion, in image space
def __init__(
self,
device,
timesteps=1000,
beta_schedule="linear",
linear_start=0.0015,
linear_end=0.0205,
cosine_s=0.008,
original_elbo_weight=0.0,
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.0,
parameterization="eps", # all assuming fixed variance schedules
use_positional_encodings=False,
):
super().__init__()
self.device = device
self.parameterization = parameterization
self.use_positional_encodings = use_positional_encodings
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
self.register_schedule(
beta_schedule=beta_schedule,
timesteps=timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
betas = make_beta_schedule(
self.device,
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (
1.0 - alphas_cumprod_prev
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer("posterior_variance", to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer(
"posterior_log_variance_clipped",
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
)
self.register_buffer(
"posterior_mean_coef1",
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
)
self.register_buffer(
"posterior_mean_coef2",
to_torch(
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
),
)
if self.parameterization == "eps":
lvlb_weights = self.betas**2 / (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
elif self.parameterization == "x0":
lvlb_weights = (
    0.5
    * np.sqrt(torch.Tensor(alphas_cumprod))
    # the (2.0 * 1 - ...) parenthesization is kept as in the original CompVis code
    / (2.0 * 1 - torch.Tensor(alphas_cumprod))
)
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
class LatentDiffusion(DDPM):
def __init__(
self,
diffusion_model,
device,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
scale_factor=1.0,
scale_by_std=False,
*args,
**kwargs,
):
self.num_timesteps_cond = 1
self.scale_by_std = scale_by_std
super().__init__(device, *args, **kwargs)
self.diffusion_model = diffusion_model
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
self.num_downs = 2
self.scale_factor = scale_factor
def make_cond_schedule(
self,
):
self.cond_ids = torch.full(
size=(self.num_timesteps,),
fill_value=self.num_timesteps - 1,
dtype=torch.long,
)
ids = torch.round(
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
).long()
self.cond_ids[: self.num_timesteps_cond] = ids
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
super().register_schedule(
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def apply_model(self, x_noisy, t, cond):
# x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
x_recon = self.diffusion_model(x_noisy, t_emb, cond)
return x_recon
class LDM(InpaintModel):
name = "ldm"
pad_mod = 32
is_erase_model = True
def __init__(self, device, fp16: bool = True, **kwargs):
self.fp16 = fp16
super().__init__(device)
self.device = device
def init_model(self, device, **kwargs):
self.diffusion_model = load_jit_model(
LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
)
self.cond_stage_model_decode = load_jit_model(
LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
)
self.cond_stage_model_encode = load_jit_model(
LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
)
if self.fp16 and "cuda" in str(device):
self.diffusion_model = self.diffusion_model.half()
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
self.model = LatentDiffusion(self.diffusion_model, device)
@staticmethod
def download():
download_model(LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5)
download_model(LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5)
download_model(LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5)
@staticmethod
def is_downloaded() -> bool:
model_paths = [
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
]
return all([os.path.exists(it) for it in model_paths])
@torch.cuda.amp.autocast()
def forward(self, image, mask, config: InpaintRequest):
"""
image: [H, W, C] RGB
mask: [H, W, 1]
return: BGR IMAGE
"""
# image [1,3,512,512] float32
# mask: [1,1,512,512] float32
# masked_image: [1,3,512,512] float32
if config.ldm_sampler == LDMSampler.ddim:
sampler = DDIMSampler(self.model)
elif config.ldm_sampler == LDMSampler.plms:
sampler = PLMSSampler(self.model)
else:
raise ValueError()
steps = config.ldm_steps
image = norm_img(image)
mask = norm_img(mask)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
masked_image = (1 - mask) * image
mask = self._norm(mask)
masked_image = self._norm(masked_image)
c = self.cond_stage_model_encode(masked_image)
torch.cuda.empty_cache()
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
c = torch.cat((c, cc), dim=1) # 1,4,128,128
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim = sampler.sample(
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
)
torch.cuda.empty_cache()
x_samples_ddim = self.cond_stage_model_decode(
samples_ddim
) # samples_ddim: 1, 3, 128, 128 float32
torch.cuda.empty_cache()
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
# inpainted = (1 - mask) * image + mask * predicted_image
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
return inpainted_image
def _norm(self, tensor):
return tensor * 2.0 - 1.0
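A small helper restating the latent shapes forward() works with (num_downs = 2, so spatial sizes shrink by a factor of 4; a sketch, not part of the model):

def ldm_latent_shapes(img_h: int, img_w: int):
    # conditioning: 3 encoded channels + 1 downscaled mask channel
    h, w = img_h // 4, img_w // 4
    return {"cond": (1, 4, h, w), "sample": (1, 3, h, w)}  # e.g. 512x512 -> 128x128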

inpaint/model/manga.py (new file, 97 lines)

@@ -0,0 +1,97 @@
import os
import random
import cv2
import numpy as np
import torch
import time
from loguru import logger
from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model
from .base import InpaintModel
from iopaint.schema import InpaintRequest
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
"MANGA_INPAINTOR_MODEL_URL",
"https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit",
)
MANGA_INPAINTOR_MODEL_MD5 = os.environ.get(
"MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c"
)
MANGA_LINE_MODEL_URL = os.environ.get(
"MANGA_LINE_MODEL_URL",
"https://github.com/Sanster/models/releases/download/manga/erika.jit",
)
MANGA_LINE_MODEL_MD5 = os.environ.get(
"MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644"
)
class Manga(InpaintModel):
name = "manga"
pad_mod = 16
is_erase_model = True
def init_model(self, device, **kwargs):
self.inpaintor_model = load_jit_model(
MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5
)
self.line_model = load_jit_model(
MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5
)
self.seed = 42
@staticmethod
def download():
download_model(MANGA_INPAINTOR_MODEL_URL, MANGA_INPAINTOR_MODEL_MD5)
download_model(MANGA_LINE_MODEL_URL, MANGA_LINE_MODEL_MD5)
@staticmethod
def is_downloaded() -> bool:
model_paths = [
get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
get_cache_path_by_url(MANGA_LINE_MODEL_URL),
]
return all([os.path.exists(it) for it in model_paths])
def forward(self, image, mask, config: InpaintRequest):
"""
image: [H, W, C] RGB
mask: [H, W, 1]
return: BGR IMAGE
"""
seed = self.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
gray_img = torch.from_numpy(
gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)
).to(self.device)
start = time.time()
lines = self.line_model(gray_img)
torch.cuda.empty_cache()
lines = torch.clamp(lines, 0, 255)
logger.info(f"erika_model time: {time.time() - start}")
mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
mask = mask.permute(0, 3, 1, 2)
mask = torch.where(mask > 0.5, 1.0, 0.0)
noise = torch.randn_like(mask)
ones = torch.ones_like(mask)
gray_img = gray_img / 255 * 2 - 1.0
lines = lines / 255 * 2 - 1.0
start = time.time()
inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
logger.info(f"image_inpaintor_model time: {time.time() - start}")
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
return cur_res

inpaint/model/mat.py (new file, 1945 lines)

(File diff suppressed because it is too large.)

inpaint/model/mi_gan.py (new file, 110 lines)

@@ -0,0 +1,110 @@
import os
import cv2
import torch
from iopaint.helper import (
load_jit_model,
download_model,
get_cache_path_by_url,
boxes_from_mask,
resize_max_size,
norm_img,
)
from .base import InpaintModel
from iopaint.schema import InpaintRequest
MIGAN_MODEL_URL = os.environ.get(
"MIGAN_MODEL_URL",
"https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
)
MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
class MIGAN(InpaintModel):
name = "migan"
min_size = 512
pad_mod = 512
pad_to_square = True
is_erase_model = True
def init_model(self, device, **kwargs):
self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()
@staticmethod
def download():
download_model(MIGAN_MODEL_URL, MIGAN_MODEL_MD5)
@staticmethod
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
@torch.no_grad()
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
if image.shape[0] == 512 and image.shape[1] == 512:
return self._pad_forward(image, mask, config)
boxes = boxes_from_mask(mask)
crop_result = []
config.hd_strategy_crop_margin = 128
for box in boxes:
crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
origin_size = crop_image.shape[:2]
resize_image = resize_max_size(crop_image, size_limit=512)
resize_mask = resize_max_size(crop_mask, size_limit=512)
inpaint_result = self._pad_forward(resize_image, resize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = crop_mask < 127
inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
original_pixel_indices
]
crop_result.append((inpaint_result, crop_box))
inpaint_result = image[:, :, ::-1].copy()
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
inpaint_result[y1:y2, x1:x2, :] = crop_image
return inpaint_result
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W] mask area == 255
return: BGR IMAGE
"""
image = norm_img(image) # [0, 1]
image = image * 2 - 1 # [0, 1] -> [-1, 1]
mask = (mask > 120) * 255
mask = norm_img(mask)
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
erased_img = image * (1 - mask)
input_image = torch.cat([0.5 - mask, erased_img], dim=1)
output = self.model(input_image)
output = (
(output.permute(0, 2, 3, 1) * 127.5 + 127.5)
.round()
.clamp(0, 255)
.to(torch.uint8)
)
output = output[0].cpu().numpy()
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return cur_res
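The 4-channel input assembled in forward() encodes the mask and the erased image together; a per-sample sketch of the same layout:

import torch

def build_migan_input(image_chw: torch.Tensor, mask_1hw: torch.Tensor) -> torch.Tensor:
    # image_chw in [-1, 1]; mask_1hw is 1.0 where pixels should be repainted
    erased = image_chw * (1 - mask_1hw)
    return torch.cat([0.5 - mask_1hw, erased], dim=0)  # [4, H, W]: +0.5 keep / -0.5 repaint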

inpaint/model/opencv2.py (new file, 29 lines)

@@ -0,0 +1,29 @@
import cv2
from .base import InpaintModel
from iopaint.schema import InpaintRequest
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
class OpenCV2(InpaintModel):
name = "cv2"
pad_mod = 1
is_erase_model = True
@staticmethod
def is_downloaded() -> bool:
return True
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1]
return: BGR IMAGE
"""
cur_res = cv2.inpaint(
image[:, :, ::-1],
mask,
inpaintRadius=config.cv2_radius,
flags=flag_map[config.cv2_flag],
)
return cur_res
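The same call without the wrapper, runnable as-is:

import cv2
import numpy as np

img = np.full((64, 64, 3), 200, np.uint8)
mask = np.zeros((64, 64), np.uint8)
mask[20:40, 20:40] = 255  # region to repaint
out = cv2.inpaint(img, mask, inpaintRadius=5, flags=cv2.INPAINT_TELEA)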


@@ -0,0 +1,19 @@
from pathlib import Path
from typing import Dict
CURRENT_DIR = Path(__file__).parent.absolute()
def get_config_files() -> Dict[str, Path]:
"""
- `v1`: Config file for Stable Diffusion v1
- `v2`: Config file for Stable Diffusion v2
- `xl`: Config file for Stable Diffusion XL
- `xl_refiner`: Config file for Stable Diffusion XL Refiner
"""
return {
"v1": CURRENT_DIR / "v1-inference.yaml",
"v2": CURRENT_DIR / "v2-inference-v.yaml",
"xl": CURRENT_DIR / "sd_xl_base.yaml",
"xl_refiner": CURRENT_DIR / "sd_xl_refiner.yaml",
}
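Usage sketch: this mapping feeds the original_config_file argument of the from_single_file calls earlier in this diff:

config_files = get_config_files()
v1_config = config_files["v1"]  # Path to v1-inference.yaml for SD 1.x checkpoints
xl_config = config_files["xl"]  # Path to sd_xl_base.yaml for SDXL checkpoints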


@@ -0,0 +1,93 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: [1, 2, 10]
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity


@@ -0,0 +1,86 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2560
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 384
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 4
context_dim: [1280, 1280, 1280, 1280]
spatial_transformer_attn_type: softmax-xformers
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
legacy: False
freeze: True
layer: penultimate
always_return_pooled: True
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity


@@ -0,0 +1,70 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


@@ -0,0 +1,68 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"


@@ -0,0 +1,68 @@
import PIL
import PIL.Image
import cv2
import torch
from loguru import logger
from iopaint.helper import decode_base64_to_image
from .base import DiffusionInpaintModel
from iopaint.schema import InpaintRequest
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
class PaintByExample(DiffusionInpaintModel):
name = "Fantasy-Studio/Paint-by-Example"
pad_mod = 8
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from diffusers import DiffusionPipeline
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {
"local_files_only": is_local_files_only(**kwargs),
}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Paint By Example Model NSFW checker")
model_kwargs.update(
dict(safety_checker=None, requires_safety_checker=False)
)
self.model = DiffusionPipeline.from_pretrained(
self.name, torch_dtype=torch_dtype, **model_kwargs
)
enable_low_mem(self.model, kwargs.get("low_mem", False))
# TODO: gpu_id
if kwargs.get("cpu_offload", False) and use_gpu:
self.model.image_encoder = self.model.image_encoder.to(device)
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
if config.paint_by_example_example_image is None:
raise ValueError("paint_by_example_example_image is required")
example_image, _, _ = decode_base64_to_image(
config.paint_by_example_example_image
)
output = self.model(
image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
example_image=PIL.Image.fromarray(example_image),
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
output_type="np.array",
generator=torch.manual_seed(config.sd_seed),
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
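A sketch of the array-to-PIL conversions used above with dummy data; the shapes follow the docstring conventions ([H, W, C] RGB image, [H, W, 1] mask).

import numpy as np
import PIL.Image

image = np.zeros((512, 512, 3), dtype=np.uint8)
mask = np.zeros((512, 512, 1), dtype=np.uint8)
mask[100:200, 100:200, 0] = 255  # 255 marks the area to repaint

pil_image = PIL.Image.fromarray(image)
pil_mask = PIL.Image.fromarray(mask[:, :, -1], mode="L")  # drop the channel dim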


@@ -0,0 +1,225 @@
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
import torch
import numpy as np
from .utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
from tqdm import tqdm
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta, verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
steps,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, ):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
return img
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
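For reference, the multistep branches above are the 2nd-, 3rd- and 4th-order pseudo linear multistep (Adams-Bashforth) updates, with eps_{t-k} denoting the entries of old_eps:

    eps'_t = (3*eps_t - eps_{t-1}) / 2
    eps'_t = (23*eps_t - 16*eps_{t-1} + 5*eps_{t-2}) / 12
    eps'_t = (55*eps_t - 59*eps_{t-1} + 37*eps_{t-2} - 9*eps_{t-3}) / 24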

File diff suppressed because it is too large


@@ -0,0 +1,101 @@
from PIL import Image
import PIL.Image
import cv2
import torch
from loguru import logger
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..utils import (
handle_from_pretrained_exceptions,
get_torch_dtype,
enable_low_mem,
is_local_files_only,
)
from iopaint.schema import InpaintRequest
from .powerpaint_tokenizer import add_task_to_prompt
from ...const import POWERPAINT_NAME
class PowerPaint(DiffusionInpaintModel):
name = POWERPAINT_NAME
pad_mod = 8
min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
def init_model(self, device: torch.device, **kwargs):
from .pipeline_powerpaint import StableDiffusionInpaintPipeline
from .powerpaint_tokenizer import PowerPaintTokenizer
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
self.model = handle_from_pretrained_exceptions(
StableDiffusionInpaintPipeline.from_pretrained,
pretrained_model_name_or_path=self.name,
variant="fp16",
torch_dtype=torch_dtype,
**model_kwargs,
)
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
img_h, img_w = image.shape[:2]
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
config.prompt, config.negative_prompt, config.powerpaint_task
)
output = self.model(
image=PIL.Image.fromarray(image),
promptA=promptA,
promptB=promptB,
tradoff=config.fitting_degree,
tradoff_nag=config.fitting_degree,
negative_promptA=negative_promptA,
negative_promptB=negative_promptB,
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
num_inference_steps=config.sd_steps,
strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output


@@ -0,0 +1,186 @@
from itertools import chain
import PIL.Image
import cv2
import torch
from iopaint.model.original_sd_configs import get_config_files
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..utils import (
get_torch_dtype,
enable_low_mem,
is_local_files_only,
handle_from_pretrained_exceptions,
)
from .powerpaint_tokenizer import task_to_prompt
from iopaint.schema import InpaintRequest, ModelType
from .v2.BrushNet_CA import BrushNetModel
from .v2.unet_2d_condition import UNet2DConditionModel_forward
from .v2.unet_2d_blocks import (
CrossAttnDownBlock2D_forward,
DownBlock2D_forward,
CrossAttnUpBlock2D_forward,
UpBlock2D_forward,
)
class PowerPaintV2(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
hf_model_id = "Sanster/PowerPaint_v2"
def init_model(self, device: torch.device, **kwargs):
from .v2.pipeline_PowerPaint_Brushnet_CA import (
StableDiffusionPowerPaintBrushNetPipeline,
)
from .powerpaint_tokenizer import PowerPaintTokenizer
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
text_encoder_brushnet = CLIPTextModel.from_pretrained(
self.hf_model_id,
subfolder="text_encoder_brushnet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
brushnet = BrushNetModel.from_pretrained(
self.hf_model_id,
subfolder="PowerPaint_Brushnet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=False,
original_config_file=get_config_files()["v1"],
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
**model_kwargs,
)
else:
pipe = handle_from_pretrained_exceptions(
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype,
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
variant="fp16",
**model_kwargs,
)
pipe.tokenizer = PowerPaintTokenizer(
CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
)
self.model = pipe
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
# Monkey patch the UNet forward method to use UNet2DConditionModel_forward
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
self.model.unet, self.model.unet.__class__
)
# Monkey patch unet and brushnet down_blocks to use the patched block forwards
for down_block in chain(
self.model.unet.down_blocks, self.model.brushnet.down_blocks
):
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
down_block, down_block.__class__
)
else:
down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)
for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
up_block, up_block.__class__
)
else:
up_block.forward = UpBlock2D_forward.__get__(
up_block, up_block.__class__
)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
image = image * (1 - mask / 255.0)
img_h, img_w = image.shape[:2]
image = PIL.Image.fromarray(image.astype(np.uint8))
mask = PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB")
promptA, promptB, negative_promptA, negative_promptB = task_to_prompt(
config.powerpaint_task
)
output = self.model(
image=image,
mask=mask,
promptA=promptA,
promptB=promptB,
promptU=config.prompt,
tradoff=config.fitting_degree,
tradoff_nag=config.fitting_degree,
negative_promptA=negative_promptA,
negative_promptB=negative_promptB,
negative_promptU=config.negative_prompt,
num_inference_steps=config.sd_steps,
# strength=config.sd_strength,
brushnet_conditioning_scale=1.0,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
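The forward replacements in init_model rely on the descriptor protocol; here is a minimal, self-contained sketch of that __get__ pattern, binding a plain function as a method on one instance without touching the class:

class Greeter:
    def hello(self):
        return "original"

def patched_hello(self):
    return "patched"

g = Greeter()
g.hello = patched_hello.__get__(g, Greeter)  # bound to this instance only
assert g.hello() == "patched"
assert Greeter().hello() == "original"       # other instances are unaffected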


@@ -0,0 +1,254 @@
import copy
import random
from typing import Any, List, Union
from transformers import CLIPTokenizer
from iopaint.schema import PowerPaintTask
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
if task == PowerPaintTask.object_remove:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
elif task == PowerPaintTask.context_aware:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.shape_guided:
promptA = prompt + " P_shape"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.outpainting:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
else:
promptA = prompt + " P_obj"
promptB = prompt + " P_obj"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
return promptA, promptB, negative_promptA, negative_promptB
def task_to_prompt(task: PowerPaintTask):
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
"", "", task
)
return (
promptA.strip(),
promptB.strip(),
negative_promptA.strip(),
negative_promptB.strip(),
)
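A quick illustration of the task-to-prompt mapping above; the enum members follow the branches in add_task_to_prompt.

from iopaint.schema import PowerPaintTask

pa, pb, na, nb = add_task_to_prompt("a cat", "blurry", PowerPaintTask.object_remove)
# pa == pb == "a cat P_ctxt"; na == nb == "blurry P_obj"

pa, pb, na, nb = task_to_prompt(PowerPaintTask.shape_guided)
# pa == "P_shape", pb == "P_ctxt", na == nb == ""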
class PowerPaintTokenizer:
def __init__(self, tokenizer: CLIPTokenizer):
self.wrapped = tokenizer
self.token_map = {}
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
num_vec_per_token = 10
for placeholder_token in placeholder_tokens:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
output.append(ith_token)
self.token_map[placeholder_token] = output
def __getattr__(self, name: str) -> Any:
if name == "wrapped":
return super().__getattr__("wrapped")
try:
return getattr(self.wrapped, name)
except AttributeError:
try:
return super().__getattr__(name)
except AttributeError:
raise AttributeError(
"'name' cannot be found in both "
f"'{self.__class__.__name__}' and "
f"'{self.__class__.__name__}.tokenizer'."
)
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
"""Attempt to add tokens to the tokenizer.
Args:
tokens (Union[str, List[str]]): The tokens to be added.
"""
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
assert num_added_tokens != 0, (
f"The tokenizer already contains the token {tokens}. Please pass "
"a different `placeholder_token` that is not already in the "
"tokenizer."
)
def get_token_info(self, token: str) -> dict:
"""Get the information of a token, including its start and end index in
the current tokenizer.
Args:
token (str): The token to be queried.
Returns:
dict: The information of the token, including its start and end
index in current tokenizer.
"""
token_ids = self.__call__(token).input_ids
start, end = token_ids[1], token_ids[-2] + 1
return {"name": token, "start": start, "end": end}
def add_placeholder_token(
self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
):
"""Add placeholder tokens to the tokenizer.
Args:
placeholder_token (str): The placeholder token to be added.
num_vec_per_token (int, optional): The number of vectors of
the added placeholder token.
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
"""
output = []
if num_vec_per_token == 1:
self.try_adding_tokens(placeholder_token, *args, **kwargs)
output.append(placeholder_token)
else:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
self.try_adding_tokens(ith_token, *args, **kwargs)
output.append(ith_token)
for token in self.token_map:
if token in placeholder_token:
raise ValueError(
f"The tokenizer already has placeholder token {token} "
f"that can get confused with {placeholder_token} "
"keep placeholder tokens independent"
)
self.token_map[placeholder_token] = output
def replace_placeholder_tokens_in_text(
self,
text: Union[str, List[str]],
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
) -> Union[str, List[str]]:
"""Replace the keywords in text with placeholder tokens. This function
will be called in `self.__call__` and `self.encode`.
Args:
text (Union[str, List[str]]): The text to be processed.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(
self.replace_placeholder_tokens_in_text(
text[i], vector_shuffle=vector_shuffle
)
)
return output
for placeholder_token in self.token_map:
if placeholder_token in text:
tokens = self.token_map[placeholder_token]
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
if vector_shuffle:
tokens = copy.copy(tokens)
random.shuffle(tokens)
text = text.replace(placeholder_token, " ".join(tokens))
return text
def replace_text_with_placeholder_tokens(
self, text: Union[str, List[str]]
) -> Union[str, List[str]]:
"""Replace the placeholder tokens in text with the original keywords.
This function will be called in `self.decode`.
Args:
text (Union[str, List[str]]): The text to be processed.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(self.replace_text_with_placeholder_tokens(text[i]))
return output
for placeholder_token, tokens in self.token_map.items():
merged_tokens = " ".join(tokens)
if merged_tokens in text:
text = text.replace(merged_tokens, placeholder_token)
return text
def __call__(
self,
text: Union[str, List[str]],
*args,
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
**kwargs,
):
"""The call function of the wrapper.
Args:
text (Union[str, List[str]]): The text to be tokenized.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
)
return self.wrapped.__call__(replaced_text, *args, **kwargs)
def encode(self, text: Union[str, List[str]], *args, **kwargs):
"""Encode the passed text to token index.
Args:
text (Union[str, List[str]]): The text to be encode.
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(text)
return self.wrapped(replaced_text, *args, **kwargs)
def decode(
self, token_ids, return_raw: bool = False, *args, **kwargs
) -> Union[str, List[str]]:
"""Decode the token index to text.
Args:
token_ids: The token index to be decoded.
return_raw: Whether keep the placeholder token in the text.
Defaults to False.
*args, **kwargs: The arguments for `self.wrapped.decode`.
Returns:
Union[str, List[str]]: The decoded text.
"""
text = self.wrapped.decode(token_ids, *args, **kwargs)
if return_raw:
return text
replaced_text = self.replace_text_with_placeholder_tokens(text)
return replaced_text
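An illustration of the placeholder expansion the wrapper performs before delegating to the wrapped tokenizer; the CLIP checkpoint id below is illustrative.

from transformers import CLIPTokenizer

tok = PowerPaintTokenizer(CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32"))
expanded = tok.replace_placeholder_tokens_in_text("a cat P_ctxt")
# -> "a cat P_ctxt_0 P_ctxt_1 ... P_ctxt_9"; this expanded string is what the
# wrapped CLIP tokenizer actually sees in __call__ and encode.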

File diff suppressed because it is too large


File diff suppressed because it is too large


@@ -0,0 +1,342 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import apply_freeu
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def CrossAttnDownBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
down_block_add_samples: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=lora_scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(
0
) # todo: add before or after
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def DownBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
down_block_add_samples: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(
0
) # todo: add before or after
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def CrossAttnUpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
return_res_samples: Optional[bool] = False,
up_block_add_samples: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
if return_res_samples:
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0)
if return_res_samples:
return hidden_states, output_states
else:
return hidden_states
def UpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
return_res_samples: Optional[bool] = False,
up_block_add_samples: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
if return_res_samples:
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(
0
) # todo: add before or after
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(
0
) # todo: add before or after
if return_res_samples:
return hidden_states, output_states
else:
return hidden_states
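A hypothetical helper mirroring the pop(0) accounting above: each patched block consumes one add-sample per resnet, plus one more when it also has a downsampler or upsampler stage, so a caller must supply exactly this many tensors per block.

def n_add_samples(block) -> int:
    # one pop per resnet in the main loop, plus one in the
    # downsampler/upsampler loop when that stage exists
    n = len(block.resnets)
    has_resampler = (
        getattr(block, "downsamplers", None) is not None
        or getattr(block, "upsamplers", None) is not None
    )
    if has_resampler:
        n += 1
    return n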


@@ -0,0 +1,402 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def UNet2DConditionModel_forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
through the `self.time_embedding` layer to obtain the timestep embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
A tuple of tensors that if specified are added to the residuals of down unet blocks.
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
example from ControlNet side model(s)
mid_block_additional_residual (`torch.Tensor`, *optional*):
additional residual to be added to UNet mid block output, for example from ControlNet side model
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (number of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (
1 - encoder_attention_mask.to(sample.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
if class_emb is not None:
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
aug_emb = self.get_aug_embed(
emb=emb,
encoder_hidden_states=encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
)
if self.config.addition_embed_type == "image_hint":
aug_emb, hint = aug_emb
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
encoder_hidden_states = self.process_encoder_hidden_states(
encoder_hidden_states=encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
)
# 2. pre-process
sample = self.conv_in(sample)
# 2.5 GLIGEN position net
if (
cross_attention_kwargs is not None
and cross_attention_kwargs.get("gligen", None) is not None
):
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = (
mid_block_additional_residual is not None
and down_block_additional_residuals is not None
)
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
# maintain backward compatibility for legacy usage, where
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
is_brushnet = (
down_block_add_samples is not None
and mid_block_add_sample is not None
and up_block_add_samples is not None
)
if (
not is_adapter
and mid_block_additional_residual is None
and down_block_additional_residuals is not None
):
deprecate(
"T2I should not use down_block_additional_residuals",
"1.3.0",
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
standard_warn=False,
)
down_intrablock_additional_residuals = down_block_additional_residuals
is_adapter = True
down_block_res_samples = (sample,)
if is_brushnet:
sample = sample + down_block_add_samples.pop(0)
for downsample_block in self.down_blocks:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
if is_adapter and len(down_intrablock_additional_residuals) > 0:
additional_residuals["additional_residuals"] = (
down_intrablock_additional_residuals.pop(0)
)
if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [
down_block_add_samples.pop(0)
for _ in range(
len(downsample_block.resnets)
+ (downsample_block.downsamplers is not None)
)
]
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
additional_residuals = {}
if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [
down_block_add_samples.pop(0)
for _ in range(
len(downsample_block.resnets)
+ (downsample_block.downsamplers is not None)
)
]
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
scale=lora_scale,
**additional_residuals,
)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)
down_block_res_samples += res_samples
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = (
down_block_res_sample + down_block_additional_residual
)
new_down_block_res_samples = new_down_block_res_samples + (
down_block_res_sample,
)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
if (
hasattr(self.mid_block, "has_cross_attention")
and self.mid_block.has_cross_attention
):
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = self.mid_block(sample, emb)
# To support T2I-Adapter-XL
if (
is_adapter
and len(down_intrablock_additional_residuals) > 0
and sample.shape == down_intrablock_additional_residuals[0].shape
):
sample += down_intrablock_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
if is_brushnet:
sample = sample + mid_block_add_sample
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [
up_block_add_samples.pop(0)
for _ in range(
len(upsample_block.resnets)
+ (upsample_block.upsamplers is not None)
)
]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [
up_block_add_samples.pop(0)
for _ in range(
len(upsample_block.resnets)
+ (upsample_block.upsamplers is not None)
)
]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
**additional_residuals,
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)

Some files were not shown because too many files have changed in this diff