removed iopaint folder
This commit is contained in:
parent
70af4845af
commit
235377d738
@ -1,23 +0,0 @@
|
||||
import os
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
|
||||
os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
|
||||
os.environ["LRU_CACHE_CAPACITY"] = "1"
|
||||
# prevent CPU memory leak when run model on GPU
|
||||
# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
|
||||
# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
|
||||
os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"
|
||||
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
|
||||
|
||||
def entry_point():
|
||||
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
|
||||
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
|
||||
from iopaint.cli import typer_app
|
||||
|
||||
typer_app()
|
@ -1,4 +0,0 @@
|
||||
from iopaint import entry_point
|
||||
|
||||
if __name__ == "__main__":
|
||||
entry_point()
|
398
iopaint/api.py
398
iopaint/api.py
@ -1,398 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import socketio
|
||||
import torch
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
import uvicorn
|
||||
from PIL import Image
|
||||
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, FileResponse, Response
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from loguru import logger
|
||||
from socketio import AsyncServer
|
||||
|
||||
from iopaint.file_manager import FileManager
|
||||
from iopaint.helper import (
|
||||
load_img,
|
||||
decode_base64_to_image,
|
||||
pil_to_bytes,
|
||||
numpy_to_bytes,
|
||||
concat_alpha_channel,
|
||||
gen_frontend_mask,
|
||||
adjust_mask,
|
||||
)
|
||||
from iopaint.model.utils import torch_gc
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
|
||||
from iopaint.plugins.base_plugin import BasePlugin
|
||||
from iopaint.plugins.remove_bg import RemoveBG
|
||||
from iopaint.schema import (
|
||||
GenInfoResponse,
|
||||
ApiConfig,
|
||||
ServerConfigResponse,
|
||||
SwitchModelRequest,
|
||||
InpaintRequest,
|
||||
RunPluginRequest,
|
||||
SDSampler,
|
||||
PluginInfo,
|
||||
AdjustMaskRequest,
|
||||
RemoveBGModel,
|
||||
SwitchPluginModelRequest,
|
||||
ModelInfo,
|
||||
InteractiveSegModel,
|
||||
RealESRGANModel,
|
||||
)
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||
WEB_APP_DIR = CURRENT_DIR / "web_app"
|
||||
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = False
|
||||
try:
|
||||
if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
|
||||
import anyio # importing just so it can be placed on silent list
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
rich_available = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def handle_exception(request: Request, e: Exception):
|
||||
err = {
|
||||
"error": type(e).__name__,
|
||||
"detail": vars(e).get("detail", ""),
|
||||
"body": vars(e).get("body", ""),
|
||||
"errors": str(e),
|
||||
}
|
||||
if not isinstance(
|
||||
e, HTTPException
|
||||
): # do not print backtrace on known httpexceptions
|
||||
message = f"API error: {request.method}: {request.url} {err}"
|
||||
if rich_available:
|
||||
print(message)
|
||||
console.print_exception(
|
||||
show_locals=True,
|
||||
max_frames=2,
|
||||
extra_lines=1,
|
||||
suppress=[anyio, starlette],
|
||||
word_wrap=False,
|
||||
width=min([console.width, 200]),
|
||||
)
|
||||
else:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(
|
||||
status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
|
||||
)
|
||||
|
||||
@app.middleware("http")
|
||||
async def exception_handling(request: Request, call_next):
|
||||
try:
|
||||
return await call_next(request)
|
||||
except Exception as e:
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def fastapi_exception_handler(request: Request, e: Exception):
|
||||
return handle_exception(request, e)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, e: HTTPException):
|
||||
return handle_exception(request, e)
|
||||
|
||||
cors_options = {
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
"allow_origins": ["*"],
|
||||
"allow_credentials": True,
|
||||
"expose_headers": ["X-Seed"],
|
||||
}
|
||||
app.add_middleware(CORSMiddleware, **cors_options)
|
||||
|
||||
|
||||
global_sio: AsyncServer = None
|
||||
|
||||
|
||||
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
|
||||
# self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
|
||||
# logger.info(f"diffusion callback: step={step}, timestep={timestep}")
|
||||
|
||||
# We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
|
||||
# but for now let's just start a separate event loop. It shouldn't make a difference for single person use
|
||||
asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
|
||||
return {}
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, config: ApiConfig):
|
||||
self.app = app
|
||||
self.config = config
|
||||
self.router = APIRouter()
|
||||
self.queue_lock = threading.Lock()
|
||||
api_middleware(self.app)
|
||||
|
||||
self.file_manager = self._build_file_manager()
|
||||
self.plugins = self._build_plugins()
|
||||
self.model_manager = self._build_model_manager()
|
||||
|
||||
# fmt: off
|
||||
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
|
||||
response_model=ServerConfigResponse)
|
||||
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
||||
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
|
||||
global global_sio
|
||||
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
|
||||
self.app.mount("/ws", self.combined_asgi_app)
|
||||
global_sio = self.sio
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
def api_save_image(self, file: UploadFile):
|
||||
filename = file.filename
|
||||
origin_image_bytes = file.file.read()
|
||||
with open(self.config.output_dir / filename, "wb") as fw:
|
||||
fw.write(origin_image_bytes)
|
||||
|
||||
def api_current_model(self) -> ModelInfo:
|
||||
return self.model_manager.current_model
|
||||
|
||||
def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
|
||||
if req.name == self.model_manager.name:
|
||||
return self.model_manager.current_model
|
||||
self.model_manager.switch(req.name)
|
||||
return self.model_manager.current_model
|
||||
|
||||
def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
|
||||
if req.plugin_name in self.plugins:
|
||||
self.plugins[req.plugin_name].switch_model(req.model_name)
|
||||
if req.plugin_name == RemoveBG.name:
|
||||
self.config.remove_bg_model = req.model_name
|
||||
if req.plugin_name == RealESRGANUpscaler.name:
|
||||
self.config.realesrgan_model = req.model_name
|
||||
if req.plugin_name == InteractiveSeg.name:
|
||||
self.config.interactive_seg_model = req.model_name
|
||||
torch_gc()
|
||||
|
||||
def api_server_config(self) -> ServerConfigResponse:
|
||||
plugins = []
|
||||
for it in self.plugins.values():
|
||||
plugins.append(
|
||||
PluginInfo(
|
||||
name=it.name,
|
||||
support_gen_image=it.support_gen_image,
|
||||
support_gen_mask=it.support_gen_mask,
|
||||
)
|
||||
)
|
||||
|
||||
return ServerConfigResponse(
|
||||
plugins=plugins,
|
||||
modelInfos=self.model_manager.scan_models(),
|
||||
removeBGModel=self.config.remove_bg_model,
|
||||
removeBGModels=RemoveBGModel.values(),
|
||||
realesrganModel=self.config.realesrgan_model,
|
||||
realesrganModels=RealESRGANModel.values(),
|
||||
interactiveSegModel=self.config.interactive_seg_model,
|
||||
interactiveSegModels=InteractiveSegModel.values(),
|
||||
enableFileManager=self.file_manager is not None,
|
||||
enableAutoSaving=self.config.output_dir is not None,
|
||||
enableControlnet=self.model_manager.enable_controlnet,
|
||||
controlnetMethod=self.model_manager.controlnet_method,
|
||||
disableModelSwitch=False,
|
||||
isDesktop=False,
|
||||
samplers=self.api_samplers(),
|
||||
)
|
||||
|
||||
def api_input_image(self) -> FileResponse:
|
||||
if self.config.input and self.config.input.is_file():
|
||||
return FileResponse(self.config.input)
|
||||
raise HTTPException(status_code=404, detail="Input image not found")
|
||||
|
||||
def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
|
||||
_, _, info = load_img(file.file.read(), return_info=True)
|
||||
parts = info.get("parameters", "").split("Negative prompt: ")
|
||||
prompt = parts[0].strip()
|
||||
negative_prompt = ""
|
||||
if len(parts) > 1:
|
||||
negative_prompt = parts[1].split("\n")[0].strip()
|
||||
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
|
||||
|
||||
def api_inpaint(self, req: InpaintRequest):
|
||||
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
|
||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||
if image.shape[:2] != mask.shape[:2]:
|
||||
raise HTTPException(
|
||||
400,
|
||||
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
|
||||
)
|
||||
|
||||
if req.paint_by_example_example_image:
|
||||
paint_by_example_image, _, _ = decode_base64_to_image(
|
||||
req.paint_by_example_example_image
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
rgb_np_img = self.model_manager(image, mask, req)
|
||||
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
|
||||
torch_gc()
|
||||
|
||||
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
|
||||
|
||||
ext = "png"
|
||||
res_img_bytes = pil_to_bytes(
|
||||
Image.fromarray(rgb_res),
|
||||
ext=ext,
|
||||
quality=self.config.quality,
|
||||
infos=infos,
|
||||
)
|
||||
|
||||
asyncio.run(self.sio.emit("diffusion_finish"))
|
||||
|
||||
return Response(
|
||||
content=res_img_bytes,
|
||||
media_type=f"image/{ext}",
|
||||
headers={"X-Seed": str(req.sd_seed)},
|
||||
)
|
||||
|
||||
def api_run_plugin_gen_image(self, req: RunPluginRequest):
|
||||
ext = "png"
|
||||
if req.name not in self.plugins:
|
||||
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||
if not self.plugins[req.name].support_gen_image:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
||||
torch_gc()
|
||||
|
||||
if bgr_or_rgba_np_img.shape[2] == 4:
|
||||
rgba_np_img = bgr_or_rgba_np_img
|
||||
else:
|
||||
rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
|
||||
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
||||
|
||||
return Response(
|
||||
content=pil_to_bytes(
|
||||
Image.fromarray(rgba_np_img),
|
||||
ext=ext,
|
||||
quality=self.config.quality,
|
||||
infos=infos,
|
||||
),
|
||||
media_type=f"image/{ext}",
|
||||
)
|
||||
|
||||
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
|
||||
if req.name not in self.plugins:
|
||||
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||
if not self.plugins[req.name].support_gen_mask:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
||||
torch_gc()
|
||||
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
||||
return Response(
|
||||
content=numpy_to_bytes(res_mask, "png"),
|
||||
media_type="image/png",
|
||||
)
|
||||
|
||||
def api_samplers(self) -> List[str]:
|
||||
return [member.value for member in SDSampler.__members__.values()]
|
||||
|
||||
def api_adjust_mask(self, req: AdjustMaskRequest):
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
||||
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
||||
|
||||
def launch(self):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(
|
||||
self.combined_asgi_app,
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
timeout_keep_alive=999999999,
|
||||
)
|
||||
|
||||
def _build_file_manager(self) -> Optional[FileManager]:
|
||||
if self.config.input and self.config.input.is_dir():
|
||||
logger.info(
|
||||
f"Input is directory, initialize file manager {self.config.input}"
|
||||
)
|
||||
|
||||
return FileManager(
|
||||
app=self.app,
|
||||
input_dir=self.config.input,
|
||||
mask_dir=self.config.mask_dir,
|
||||
output_dir=self.config.output_dir,
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_plugins(self) -> Dict[str, BasePlugin]:
|
||||
return build_plugins(
|
||||
self.config.enable_interactive_seg,
|
||||
self.config.interactive_seg_model,
|
||||
self.config.interactive_seg_device,
|
||||
self.config.enable_remove_bg,
|
||||
self.config.remove_bg_model,
|
||||
self.config.enable_anime_seg,
|
||||
self.config.enable_realesrgan,
|
||||
self.config.realesrgan_device,
|
||||
self.config.realesrgan_model,
|
||||
self.config.enable_gfpgan,
|
||||
self.config.gfpgan_device,
|
||||
self.config.enable_restoreformer,
|
||||
self.config.restoreformer_device,
|
||||
self.config.no_half,
|
||||
)
|
||||
|
||||
def _build_model_manager(self):
|
||||
return ModelManager(
|
||||
name=self.config.model,
|
||||
device=torch.device(self.config.device),
|
||||
no_half=self.config.no_half,
|
||||
low_mem=self.config.low_mem,
|
||||
disable_nsfw=self.config.disable_nsfw_checker,
|
||||
sd_cpu_textencoder=self.config.cpu_textencoder,
|
||||
local_files_only=self.config.local_files_only,
|
||||
cpu_offload=self.config.cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
@ -1,128 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TimeElapsedColumn,
|
||||
MofNCompleteColumn,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
)
|
||||
|
||||
from iopaint.helper import pil_to_bytes
|
||||
from iopaint.model.utils import torch_gc
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
|
||||
def glob_images(path: Path) -> Dict[str, Path]:
|
||||
# png/jpg/jpeg
|
||||
if path.is_file():
|
||||
return {path.stem: path}
|
||||
elif path.is_dir():
|
||||
res = {}
|
||||
for it in path.glob("*.*"):
|
||||
if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
|
||||
res[it.stem] = it
|
||||
return res
|
||||
|
||||
|
||||
def batch_inpaint(
|
||||
model: str,
|
||||
device,
|
||||
image: Path,
|
||||
mask: Path,
|
||||
output: Path,
|
||||
config: Optional[Path] = None,
|
||||
concat: bool = False,
|
||||
):
|
||||
if image.is_dir() and output.is_file():
|
||||
logger.error(
|
||||
"invalid --output: when image is a directory, output should be a directory"
|
||||
)
|
||||
exit(-1)
|
||||
output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
image_paths = glob_images(image)
|
||||
mask_paths = glob_images(mask)
|
||||
if len(image_paths) == 0:
|
||||
logger.error("invalid --image: empty image folder")
|
||||
exit(-1)
|
||||
if len(mask_paths) == 0:
|
||||
logger.error("invalid --mask: empty mask folder")
|
||||
exit(-1)
|
||||
|
||||
if config is None:
|
||||
inpaint_request = InpaintRequest()
|
||||
logger.info(f"Using default config: {inpaint_request}")
|
||||
else:
|
||||
with open(config, "r", encoding="utf-8") as f:
|
||||
inpaint_request = InpaintRequest(**json.load(f))
|
||||
logger.info(f"Using config: {inpaint_request}")
|
||||
|
||||
model_manager = ModelManager(name=model, device=device)
|
||||
first_mask = list(mask_paths.values())[0]
|
||||
|
||||
console = Console()
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
transient=False,
|
||||
) as progress:
|
||||
task = progress.add_task("Batch processing...", total=len(image_paths))
|
||||
for stem, image_p in image_paths.items():
|
||||
if stem not in mask_paths and mask.is_dir():
|
||||
progress.log(f"mask for {image_p} not found")
|
||||
progress.update(task, advance=1)
|
||||
continue
|
||||
mask_p = mask_paths.get(stem, first_mask)
|
||||
|
||||
infos = Image.open(image_p).info
|
||||
|
||||
img = np.array(Image.open(image_p).convert("RGB"))
|
||||
mask_img = np.array(Image.open(mask_p).convert("L"))
|
||||
|
||||
if mask_img.shape[:2] != img.shape[:2]:
|
||||
progress.log(
|
||||
f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
|
||||
)
|
||||
mask_img = cv2.resize(
|
||||
mask_img,
|
||||
(img.shape[1], img.shape[0]),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
mask_img[mask_img >= 127] = 255
|
||||
mask_img[mask_img < 127] = 0
|
||||
|
||||
# bgr
|
||||
inpaint_result = model_manager(img, mask_img, inpaint_request)
|
||||
inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
|
||||
if concat:
|
||||
mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
|
||||
inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
|
||||
|
||||
img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
|
||||
save_p = output / f"{stem}.png"
|
||||
with open(save_p, "wb") as fw:
|
||||
fw.write(img_bytes)
|
||||
|
||||
progress.update(task, advance=1)
|
||||
torch_gc()
|
||||
# pid = psutil.Process().pid
|
||||
# memory_info = psutil.Process(pid).memory_info()
|
||||
# memory_in_mb = memory_info.rss / (1024 * 1024)
|
||||
# print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")
|
@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import nvidia_smi
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
NUM_THREADS = str(4)
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||
if os.environ.get("CACHE_DIR"):
|
||||
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
||||
|
||||
|
||||
def run_model(model, size):
|
||||
# RGB
|
||||
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
||||
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
||||
|
||||
config = InpaintRequest(
|
||||
ldm_steps=2,
|
||||
hd_strategy=HDStrategy.ORIGINAL,
|
||||
hd_strategy_crop_margin=128,
|
||||
hd_strategy_crop_trigger_size=128,
|
||||
hd_strategy_resize_limit=128,
|
||||
prompt="a fox is sitting on a bench",
|
||||
sd_steps=5,
|
||||
sd_sampler=SDSampler.ddim,
|
||||
)
|
||||
model(image, mask, config)
|
||||
|
||||
|
||||
def benchmark(model, times: int, empty_cache: bool):
|
||||
sizes = [(512, 512)]
|
||||
|
||||
nvidia_smi.nvmlInit()
|
||||
device_id = 0
|
||||
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
|
||||
|
||||
def format(metrics):
|
||||
return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
# 每个 size 给出显存和内存占用的指标
|
||||
for size in sizes:
|
||||
torch.cuda.empty_cache()
|
||||
time_metrics = []
|
||||
cpu_metrics = []
|
||||
memory_metrics = []
|
||||
gpu_memory_metrics = []
|
||||
for _ in range(times):
|
||||
start = time.time()
|
||||
run_model(model, size)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# cpu_metrics.append(process.cpu_percent())
|
||||
time_metrics.append((time.time() - start) * 1000)
|
||||
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
|
||||
gpu_memory_metrics.append(
|
||||
nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
|
||||
)
|
||||
|
||||
print(f"size: {size}".center(80, "-"))
|
||||
# print(f"cpu: {format(cpu_metrics)}")
|
||||
print(f"latency: {format(time_metrics)}ms")
|
||||
print(f"memory: {format(memory_metrics)} MB")
|
||||
print(f"gpu memory: {format(gpu_memory_metrics)} MB")
|
||||
|
||||
nvidia_smi.nvmlShutdown()
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--name")
|
||||
parser.add_argument("--device", default="cuda", type=str)
|
||||
parser.add_argument("--times", default=10, type=int)
|
||||
parser.add_argument("--empty-cache", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args_parser()
|
||||
device = torch.device(args.device)
|
||||
model = ModelManager(
|
||||
name=args.name,
|
||||
device=device,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
)
|
||||
benchmark(model, args.times, args.empty_cache)
|
232
iopaint/cli.py
232
iopaint/cli.py
@ -1,232 +0,0 @@
|
||||
import webbrowser
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from fastapi import FastAPI
|
||||
from loguru import logger
|
||||
from typer import Option
|
||||
from typer_config import use_json_config
|
||||
|
||||
from iopaint.const import *
|
||||
from iopaint.runtime import setup_model_dir, dump_environment_info, check_device
|
||||
from iopaint.schema import InteractiveSegModel, Device, RealESRGANModel, RemoveBGModel
|
||||
|
||||
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
|
||||
|
||||
|
||||
@typer_app.command(help="Install all plugins dependencies")
|
||||
def install_plugins_packages():
|
||||
from iopaint.installer import install_plugins_package
|
||||
|
||||
install_plugins_package()
|
||||
|
||||
|
||||
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
|
||||
def download(
|
||||
model: str = Option(
|
||||
..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
|
||||
),
|
||||
model_dir: Path = Option(
|
||||
DEFAULT_MODEL_DIR,
|
||||
help=MODEL_DIR_HELP,
|
||||
file_okay=False,
|
||||
callback=setup_model_dir,
|
||||
),
|
||||
):
|
||||
from iopaint.download import cli_download_model
|
||||
|
||||
cli_download_model(model)
|
||||
|
||||
|
||||
@typer_app.command(name="list", help="List downloaded models")
|
||||
def list_model(
|
||||
model_dir: Path = Option(
|
||||
DEFAULT_MODEL_DIR,
|
||||
help=MODEL_DIR_HELP,
|
||||
file_okay=False,
|
||||
callback=setup_model_dir,
|
||||
),
|
||||
):
|
||||
from iopaint.download import scan_models
|
||||
|
||||
scanned_models = scan_models()
|
||||
for it in scanned_models:
|
||||
print(it.name)
|
||||
|
||||
|
||||
@typer_app.command(help="Batch processing images")
|
||||
def run(
|
||||
model: str = Option("lama"),
|
||||
device: Device = Option(Device.cpu),
|
||||
image: Path = Option(..., help="Image folders or file path"),
|
||||
mask: Path = Option(
|
||||
...,
|
||||
help="Mask folders or file path. "
|
||||
"If it is a directory, the mask images in the directory should have the same name as the original image."
|
||||
"If it is a file, all images will use this mask."
|
||||
"Mask will automatically resize to the same size as the original image.",
|
||||
),
|
||||
output: Path = Option(..., help="Output directory or file path"),
|
||||
config: Path = Option(
|
||||
None, help="Config file path. You can use dump command to create a base config."
|
||||
),
|
||||
concat: bool = Option(
|
||||
False, help="Concat original image, mask and output images into one image"
|
||||
),
|
||||
model_dir: Path = Option(
|
||||
DEFAULT_MODEL_DIR,
|
||||
help=MODEL_DIR_HELP,
|
||||
file_okay=False,
|
||||
callback=setup_model_dir,
|
||||
),
|
||||
):
|
||||
from iopaint.download import cli_download_model, scan_models
|
||||
|
||||
scanned_models = scan_models()
|
||||
if model not in [it.name for it in scanned_models]:
|
||||
logger.info(f"{model} not found in {model_dir}, try to downloading")
|
||||
cli_download_model(model)
|
||||
|
||||
from iopaint.batch_processing import batch_inpaint
|
||||
|
||||
batch_inpaint(model, device, image, mask, output, config, concat)
|
||||
|
||||
|
||||
@typer_app.command(help="Start IOPaint server")
|
||||
@use_json_config()
|
||||
def start(
|
||||
host: str = Option("127.0.0.1"),
|
||||
port: int = Option(8080),
|
||||
inbrowser: bool = Option(False, help=INBROWSER_HELP),
|
||||
model: str = Option(
|
||||
DEFAULT_MODEL,
|
||||
help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n"
|
||||
f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.",
|
||||
),
|
||||
model_dir: Path = Option(
|
||||
DEFAULT_MODEL_DIR,
|
||||
help=MODEL_DIR_HELP,
|
||||
dir_okay=True,
|
||||
file_okay=False,
|
||||
callback=setup_model_dir,
|
||||
),
|
||||
low_mem: bool = Option(False, help=LOW_MEM_HELP),
|
||||
no_half: bool = Option(False, help=NO_HALF_HELP),
|
||||
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
|
||||
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
|
||||
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
|
||||
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
||||
device: Device = Option(Device.cpu),
|
||||
input: Optional[Path] = Option(None, help=INPUT_HELP),
|
||||
mask_dir: Optional[Path] = Option(
|
||||
None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
output_dir: Optional[Path] = Option(
|
||||
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
quality: int = Option(95, help=QUALITY_HELP),
|
||||
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
|
||||
interactive_seg_model: InteractiveSegModel = Option(
|
||||
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
|
||||
),
|
||||
interactive_seg_device: Device = Option(Device.cpu),
|
||||
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
|
||||
remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
|
||||
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
|
||||
enable_realesrgan: bool = Option(False),
|
||||
realesrgan_device: Device = Option(Device.cpu),
|
||||
realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
|
||||
enable_gfpgan: bool = Option(False),
|
||||
gfpgan_device: Device = Option(Device.cpu),
|
||||
enable_restoreformer: bool = Option(False),
|
||||
restoreformer_device: Device = Option(Device.cpu),
|
||||
):
|
||||
dump_environment_info()
|
||||
device = check_device(device)
|
||||
if input and not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
exit(-1)
|
||||
if mask_dir and not mask_dir.exists():
|
||||
logger.error(f"invalid --mask-dir: {mask_dir} not exists")
|
||||
exit(-1)
|
||||
if input and input.is_dir() and not output_dir:
|
||||
logger.error("invalid --output-dir: --output-dir must be set when --input is a directory")
|
||||
exit(-1)
|
||||
if output_dir:
|
||||
output_dir = output_dir.expanduser().absolute()
|
||||
logger.info(f"Image will be saved to {output_dir}")
|
||||
if not output_dir.exists():
|
||||
logger.info(f"Create output directory {output_dir}")
|
||||
output_dir.mkdir(parents=True)
|
||||
if mask_dir:
|
||||
mask_dir = mask_dir.expanduser().absolute()
|
||||
|
||||
model_dir = model_dir.expanduser().absolute()
|
||||
|
||||
if local_files_only:
|
||||
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
from iopaint.download import cli_download_model, scan_models
|
||||
|
||||
scanned_models = scan_models()
|
||||
if model not in [it.name for it in scanned_models]:
|
||||
logger.info(f"{model} not found in {model_dir}, try to downloading")
|
||||
cli_download_model(model)
|
||||
|
||||
from iopaint.api import Api
|
||||
from iopaint.schema import ApiConfig
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
if inbrowser:
|
||||
webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True)
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
api_config = ApiConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
inbrowser=inbrowser,
|
||||
model=model,
|
||||
no_half=no_half,
|
||||
low_mem=low_mem,
|
||||
cpu_offload=cpu_offload,
|
||||
disable_nsfw_checker=disable_nsfw_checker,
|
||||
local_files_only=local_files_only,
|
||||
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
|
||||
device=device,
|
||||
input=input,
|
||||
mask_dir=mask_dir,
|
||||
output_dir=output_dir,
|
||||
quality=quality,
|
||||
enable_interactive_seg=enable_interactive_seg,
|
||||
interactive_seg_model=interactive_seg_model,
|
||||
interactive_seg_device=interactive_seg_device,
|
||||
enable_remove_bg=enable_remove_bg,
|
||||
remove_bg_model=remove_bg_model,
|
||||
enable_anime_seg=enable_anime_seg,
|
||||
enable_realesrgan=enable_realesrgan,
|
||||
realesrgan_device=realesrgan_device,
|
||||
realesrgan_model=realesrgan_model,
|
||||
enable_gfpgan=enable_gfpgan,
|
||||
gfpgan_device=gfpgan_device,
|
||||
enable_restoreformer=enable_restoreformer,
|
||||
restoreformer_device=restoreformer_device,
|
||||
)
|
||||
print(api_config.model_dump_json(indent=4))
|
||||
api = Api(app, api_config)
|
||||
api.launch()
|
||||
|
||||
|
||||
@typer_app.command(help="Start IOPaint web config page")
|
||||
def start_web_config(
|
||||
config_file: Path = Option("config.json"),
|
||||
):
|
||||
dump_environment_info()
|
||||
from iopaint.web_config import main
|
||||
|
||||
main(config_file)
|
128
iopaint/const.py
128
iopaint/const.py
@ -1,128 +0,0 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
|
||||
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||
ANYTEXT_NAME = "Sanster/AnyText"
|
||||
|
||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
||||
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
|
||||
|
||||
MPS_UNSUPPORT_MODELS = [
|
||||
"lama",
|
||||
"ldm",
|
||||
"zits",
|
||||
"mat",
|
||||
"fcf",
|
||||
"cv2",
|
||||
"manga",
|
||||
]
|
||||
|
||||
DEFAULT_MODEL = "lama"
|
||||
AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
|
||||
DIFFUSION_MODELS = [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"Uminosachi/realisticVisionV51_v51VAE-inpainting",
|
||||
"redstonehero/dreamshaper-inpainting",
|
||||
"Sanster/anything-4.0-inpainting",
|
||||
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
"Fantasy-Studio/Paint-by-Example",
|
||||
POWERPAINT_NAME,
|
||||
ANYTEXT_NAME,
|
||||
]
|
||||
|
||||
NO_HALF_HELP = """
|
||||
Using full precision(fp32) model.
|
||||
If your diffusion model generate result is always black or green, use this argument.
|
||||
"""
|
||||
|
||||
CPU_OFFLOAD_HELP = """
|
||||
Offloads diffusion model's weight to CPU RAM, significantly reducing vRAM usage.
|
||||
"""
|
||||
|
||||
LOW_MEM_HELP = "Enable attention slicing and vae tiling to save memory."
|
||||
|
||||
DISABLE_NSFW_HELP = """
|
||||
Disable NSFW checker for diffusion model.
|
||||
"""
|
||||
|
||||
CPU_TEXTENCODER_HELP = """
|
||||
Run diffusion models text encoder on CPU to reduce vRAM usage.
|
||||
"""
|
||||
|
||||
SD_CONTROLNET_CHOICES: List[str] = [
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
# "lllyasviel/control_v11p_sd15_seg",
|
||||
"lllyasviel/control_v11p_sd15_openpose",
|
||||
"lllyasviel/control_v11p_sd15_inpaint",
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
]
|
||||
|
||||
SD_BRUSHNET_CHOICES: List[str] = [
|
||||
"Sanster/brushnet_random_mask",
|
||||
"Sanster/brushnet_segmentation_mask",
|
||||
]
|
||||
|
||||
SD2_CONTROLNET_CHOICES = [
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||
]
|
||||
|
||||
SDXL_CONTROLNET_CHOICES = [
|
||||
"thibaud/controlnet-openpose-sdxl-1.0",
|
||||
"destitech/controlnet-inpaint-dreamer-sdxl",
|
||||
"diffusers/controlnet-canny-sdxl-1.0",
|
||||
"diffusers/controlnet-canny-sdxl-1.0-mid",
|
||||
"diffusers/controlnet-canny-sdxl-1.0-small",
|
||||
"diffusers/controlnet-depth-sdxl-1.0",
|
||||
"diffusers/controlnet-depth-sdxl-1.0-mid",
|
||||
"diffusers/controlnet-depth-sdxl-1.0-small",
|
||||
]
|
||||
|
||||
LOCAL_FILES_ONLY_HELP = """
|
||||
When loading diffusion models, using local files only, not connect to HuggingFace server.
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL_DIR = os.path.abspath(
|
||||
os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
|
||||
)
|
||||
|
||||
MODEL_DIR_HELP = f"""
|
||||
Model download directory (by setting XDG_CACHE_HOME environment variable), by default model download to {DEFAULT_MODEL_DIR}
|
||||
"""
|
||||
|
||||
OUTPUT_DIR_HELP = """
|
||||
Result images will be saved to output directory automatically.
|
||||
"""
|
||||
|
||||
MASK_DIR_HELP = """
|
||||
You can view masks in FileManager
|
||||
"""
|
||||
|
||||
INPUT_HELP = """
|
||||
If input is image, it will be loaded by default.
|
||||
If input is directory, you can browse and select image in file manager.
|
||||
"""
|
||||
|
||||
GUI_HELP = """
|
||||
Launch Lama Cleaner as desktop app
|
||||
"""
|
||||
|
||||
QUALITY_HELP = """
|
||||
Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size.
|
||||
"""
|
||||
|
||||
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
|
||||
INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
|
||||
REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU"
|
||||
ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
|
||||
REALESRGAN_HELP = "Enable realesrgan super resolution"
|
||||
GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
|
||||
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
|
||||
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
|
||||
|
||||
INBROWSER_HELP = "Automatically launch IOPaint in a new tab on the default browser"
|
@ -1,313 +0,0 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from iopaint.schema import ModelType, ModelInfo
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
|
||||
from iopaint.const import (
|
||||
DEFAULT_MODEL_DIR,
|
||||
DIFFUSERS_SD_CLASS_NAME,
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
|
||||
ANYTEXT_NAME,
|
||||
)
|
||||
from iopaint.model.original_sd_configs import get_config_files
|
||||
|
||||
|
||||
def cli_download_model(model: str):
|
||||
from iopaint.model import models
|
||||
from iopaint.model.utils import handle_from_pretrained_exceptions
|
||||
|
||||
if model in models and models[model].is_erase_model:
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info("Done.")
|
||||
elif model == ANYTEXT_NAME:
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info("Done.")
|
||||
else:
|
||||
logger.info(f"Downloading model from Huggingface: {model}")
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
downloaded_path = handle_from_pretrained_exceptions(
|
||||
DiffusionPipeline.download,
|
||||
pretrained_model_name=model,
|
||||
variant="fp16",
|
||||
resume_download=True,
|
||||
)
|
||||
logger.info(f"Done. Downloaded to {downloaded_path}")
|
||||
|
||||
|
||||
def folder_name_to_show_name(name: str) -> str:
|
||||
return name.replace("models--", "").replace("--", "/")
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]:
|
||||
if "inpaint" in Path(model_abs_path).name.lower():
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
else:
|
||||
# load once to check num_in_channels
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
|
||||
try:
|
||||
StableDiffusionInpaintPipeline.from_single_file(
|
||||
model_abs_path,
|
||||
load_safety_checker=False,
|
||||
num_in_channels=9,
|
||||
original_config_file=get_config_files()['v1']
|
||||
)
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
except ValueError as e:
|
||||
if "[320, 4, 3, 3]" in str(e):
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
else:
|
||||
logger.info(f"Ignore non sdxl file: {model_abs_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load {model_abs_path}: {e}")
|
||||
return
|
||||
return model_type
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]:
|
||||
if "inpaint" in model_abs_path:
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
# load once to check num_in_channels
|
||||
from diffusers import StableDiffusionXLInpaintPipeline
|
||||
|
||||
try:
|
||||
model = StableDiffusionXLInpaintPipeline.from_single_file(
|
||||
model_abs_path,
|
||||
load_safety_checker=False,
|
||||
num_in_channels=9,
|
||||
original_config_file=get_config_files()['xl'],
|
||||
)
|
||||
if model.unet.config.in_channels == 9:
|
||||
# https://github.com/huggingface/diffusers/issues/6610
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
except ValueError as e:
|
||||
if "[320, 4, 3, 3]" in str(e):
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
else:
|
||||
logger.info(f"Ignore non sdxl file: {model_abs_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load {model_abs_path}: {e}")
|
||||
return
|
||||
return model_type
|
||||
|
||||
|
||||
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
|
||||
cache_dir = Path(cache_dir)
|
||||
stable_diffusion_dir = cache_dir / "stable_diffusion"
|
||||
cache_file = stable_diffusion_dir / "iopaint_cache.json"
|
||||
model_type_cache = {}
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
model_type_cache = json.load(f)
|
||||
assert isinstance(model_type_cache, dict)
|
||||
except:
|
||||
pass
|
||||
|
||||
res = []
|
||||
for it in stable_diffusion_dir.glob("*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
model_abs_path = str(it.absolute())
|
||||
model_type = model_type_cache.get(it.name)
|
||||
if model_type is None:
|
||||
model_type = get_sd_model_type(model_abs_path)
|
||||
if model_type is None:
|
||||
continue
|
||||
|
||||
model_type_cache[it.name] = model_type
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=it.name,
|
||||
path=model_abs_path,
|
||||
model_type=model_type,
|
||||
is_single_file_diffusers=True,
|
||||
)
|
||||
)
|
||||
if stable_diffusion_dir.exists():
|
||||
with open(cache_file, "w", encoding="utf-8") as fw:
|
||||
json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
|
||||
|
||||
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
||||
sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
|
||||
sdxl_model_type_cache = {}
|
||||
if sdxl_cache_file.exists():
|
||||
try:
|
||||
with open(sdxl_cache_file, "r", encoding="utf-8") as f:
|
||||
sdxl_model_type_cache = json.load(f)
|
||||
assert isinstance(sdxl_model_type_cache, dict)
|
||||
except:
|
||||
pass
|
||||
|
||||
for it in stable_diffusion_xl_dir.glob("*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
model_abs_path = str(it.absolute())
|
||||
model_type = sdxl_model_type_cache.get(it.name)
|
||||
if model_type is None:
|
||||
model_type = get_sdxl_model_type(model_abs_path)
|
||||
if model_type is None:
|
||||
continue
|
||||
|
||||
sdxl_model_type_cache[it.name] = model_type
|
||||
if stable_diffusion_xl_dir.exists():
|
||||
with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
|
||||
json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
|
||||
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=it.name,
|
||||
path=model_abs_path,
|
||||
model_type=model_type,
|
||||
is_single_file_diffusers=True,
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
|
||||
res = []
|
||||
from iopaint.model import models
|
||||
|
||||
# logger.info(f"Scanning inpaint models in {model_dir}")
|
||||
|
||||
for name, m in models.items():
|
||||
if m.is_erase_model and m.is_downloaded():
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=name,
|
||||
path=name,
|
||||
model_type=ModelType.INPAINT,
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
def scan_diffusers_models() -> List[ModelInfo]:
|
||||
from huggingface_hub.constants import HF_HUB_CACHE
|
||||
|
||||
available_models = []
|
||||
cache_dir = Path(HF_HUB_CACHE)
|
||||
# logger.info(f"Scanning diffusers models in {cache_dir}")
|
||||
diffusers_model_names = []
|
||||
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||
for it in model_index_files:
|
||||
it = Path(it)
|
||||
with open(it, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
data = json.load(f)
|
||||
except:
|
||||
continue
|
||||
|
||||
_class_name = data["_class_name"]
|
||||
name = folder_name_to_show_name(it.parent.parent.parent.name)
|
||||
if name in diffusers_model_names:
|
||||
continue
|
||||
if "PowerPaint" in name:
|
||||
model_type = ModelType.DIFFUSERS_OTHER
|
||||
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
elif _class_name in [
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
"AnyText",
|
||||
]:
|
||||
model_type = ModelType.DIFFUSERS_OTHER
|
||||
else:
|
||||
continue
|
||||
|
||||
diffusers_model_names.append(name)
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
name=name,
|
||||
path=name,
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
return available_models
|
||||
|
||||
|
||||
def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
|
||||
cache_dir = Path(cache_dir)
|
||||
available_models = []
|
||||
diffusers_model_names = []
|
||||
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||
for it in model_index_files:
|
||||
it = Path(it)
|
||||
with open(it, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
data = json.load(f)
|
||||
except:
|
||||
logger.error(
|
||||
f"Failed to load {it}, please try revert from original model or fix model_index.json by hand."
|
||||
)
|
||||
continue
|
||||
|
||||
_class_name = data["_class_name"]
|
||||
name = folder_name_to_show_name(it.parent.name)
|
||||
if name in diffusers_model_names:
|
||||
continue
|
||||
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
continue
|
||||
|
||||
diffusers_model_names.append(name)
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
name=name,
|
||||
path=str(it.parent.absolute()),
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
return available_models
|
||||
|
||||
|
||||
def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
|
||||
cache_dir = Path(cache_dir)
|
||||
available_models = []
|
||||
stable_diffusion_dir = cache_dir / "stable_diffusion"
|
||||
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
||||
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir))
|
||||
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir))
|
||||
return available_models
|
||||
|
||||
|
||||
def scan_models() -> List[ModelInfo]:
|
||||
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
|
||||
available_models = []
|
||||
available_models.extend(scan_inpaint_models(model_dir))
|
||||
available_models.extend(scan_single_file_diffusion_models(model_dir))
|
||||
available_models.extend(scan_diffusers_models())
|
||||
available_models.extend(scan_converted_diffusers_models(model_dir))
|
||||
return available_models
|
@ -1 +0,0 @@
|
||||
from .file_manager import FileManager
|
@ -1,218 +0,0 @@
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
from ..schema import MediasResponse, MediaTab
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
from .storage_backends import FilesystemStorageBackend
|
||||
from .utils import aspect_to_string, generate_filename, glob_img
|
||||
|
||||
|
||||
class FileManager:
|
||||
def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
|
||||
self.app = app
|
||||
self.input_dir: Path = input_dir
|
||||
self.mask_dir: Path = mask_dir
|
||||
self.output_dir: Path = output_dir
|
||||
|
||||
self.image_dir_filenames = []
|
||||
self.output_dir_filenames = []
|
||||
if not self.thumbnail_directory.exists():
|
||||
self.thumbnail_directory.mkdir(parents=True)
|
||||
|
||||
# fmt: off
|
||||
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
|
||||
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
|
||||
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
|
||||
# fmt: on
|
||||
|
||||
def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
|
||||
img_dir = self._get_dir(tab)
|
||||
return self._media_names(img_dir)
|
||||
|
||||
def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
|
||||
file_path = self._get_file(tab, filename)
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
# tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
|
||||
def api_media_thumbnail_file(
|
||||
self, tab: MediaTab, filename: str, width: int, height: int
|
||||
) -> FileResponse:
|
||||
img_dir = self._get_dir(tab)
|
||||
thumb_filename, (width, height) = self.get_thumbnail(
|
||||
img_dir, filename, width=width, height=height
|
||||
)
|
||||
thumbnail_filepath = self.thumbnail_directory / thumb_filename
|
||||
return FileResponse(
|
||||
thumbnail_filepath,
|
||||
headers={
|
||||
"X-Width": str(width),
|
||||
"X-Height": str(height),
|
||||
},
|
||||
media_type="image/jpeg",
|
||||
)
|
||||
|
||||
def _get_dir(self, tab: MediaTab) -> Path:
|
||||
if tab == "input":
|
||||
return self.input_dir
|
||||
elif tab == "output":
|
||||
return self.output_dir
|
||||
elif tab == "mask":
|
||||
return self.mask_dir
|
||||
else:
|
||||
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
|
||||
|
||||
def _get_file(self, tab: MediaTab, filename: str) -> Path:
|
||||
file_path = self._get_dir(tab) / filename
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
|
||||
return file_path
|
||||
|
||||
@property
|
||||
def thumbnail_directory(self) -> Path:
|
||||
return self.output_dir / "thumbnails"
|
||||
|
||||
@staticmethod
|
||||
def _media_names(directory: Path) -> List[MediasResponse]:
|
||||
names = sorted([it.name for it in glob_img(directory)])
|
||||
res = []
|
||||
for name in names:
|
||||
path = os.path.join(directory, name)
|
||||
img = Image.open(path)
|
||||
res.append(
|
||||
MediasResponse(
|
||||
name=name,
|
||||
height=img.height,
|
||||
width=img.width,
|
||||
ctime=os.path.getctime(path),
|
||||
mtime=os.path.getmtime(path),
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
def get_thumbnail(
|
||||
self, directory: Path, original_filename: str, width, height, **options
|
||||
):
|
||||
directory = Path(directory)
|
||||
storage = FilesystemStorageBackend(self.app)
|
||||
crop = options.get("crop", "fit")
|
||||
background = options.get("background")
|
||||
quality = options.get("quality", 90)
|
||||
|
||||
original_path, original_filename = os.path.split(original_filename)
|
||||
original_filepath = os.path.join(directory, original_path, original_filename)
|
||||
image = Image.open(BytesIO(storage.read(original_filepath)))
|
||||
|
||||
# keep ratio resize
|
||||
if not width and not height:
|
||||
width = 256
|
||||
|
||||
if width != 0:
|
||||
height = int(image.height * width / image.width)
|
||||
else:
|
||||
width = int(image.width * height / image.height)
|
||||
|
||||
thumbnail_size = (width, height)
|
||||
|
||||
thumbnail_filename = generate_filename(
|
||||
directory,
|
||||
original_filename,
|
||||
aspect_to_string(thumbnail_size),
|
||||
crop,
|
||||
background,
|
||||
quality,
|
||||
)
|
||||
|
||||
thumbnail_filepath = os.path.join(
|
||||
self.thumbnail_directory, original_path, thumbnail_filename
|
||||
)
|
||||
|
||||
if storage.exists(thumbnail_filepath):
|
||||
return thumbnail_filepath, (width, height)
|
||||
|
||||
try:
|
||||
image.load()
|
||||
except (IOError, OSError):
|
||||
self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
|
||||
return thumbnail_filepath, (width, height)
|
||||
|
||||
# get original image format
|
||||
options["format"] = options.get("format", image.format)
|
||||
|
||||
image = self._create_thumbnail(
|
||||
image, thumbnail_size, crop, background=background
|
||||
)
|
||||
|
||||
raw_data = self.get_raw_data(image, **options)
|
||||
storage.save(thumbnail_filepath, raw_data)
|
||||
|
||||
return thumbnail_filepath, (width, height)
|
||||
|
||||
def get_raw_data(self, image, **options):
|
||||
data = {
|
||||
"format": self._get_format(image, **options),
|
||||
"quality": options.get("quality", 90),
|
||||
}
|
||||
|
||||
_file = BytesIO()
|
||||
image.save(_file, **data)
|
||||
return _file.getvalue()
|
||||
|
||||
@staticmethod
|
||||
def colormode(image, colormode="RGB"):
|
||||
if colormode == "RGB" or colormode == "RGBA":
|
||||
if image.mode == "RGBA":
|
||||
return image
|
||||
if image.mode == "LA":
|
||||
return image.convert("RGBA")
|
||||
return image.convert(colormode)
|
||||
|
||||
if colormode == "GRAY":
|
||||
return image.convert("L")
|
||||
|
||||
return image.convert(colormode)
|
||||
|
||||
@staticmethod
|
||||
def background(original_image, color=0xFF):
|
||||
size = (max(original_image.size),) * 2
|
||||
image = Image.new("L", size, color)
|
||||
image.paste(
|
||||
original_image,
|
||||
tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def _get_format(self, image, **options):
|
||||
if options.get("format"):
|
||||
return options.get("format")
|
||||
if image.format:
|
||||
return image.format
|
||||
|
||||
return "JPEG"
|
||||
|
||||
def _create_thumbnail(self, image, size, crop="fit", background=None):
|
||||
try:
|
||||
resample = Image.Resampling.LANCZOS
|
||||
except AttributeError: # pylint: disable=raise-missing-from
|
||||
resample = Image.ANTIALIAS
|
||||
|
||||
if crop == "fit":
|
||||
image = ImageOps.fit(image, size, resample)
|
||||
else:
|
||||
image = image.copy()
|
||||
image.thumbnail(size, resample=resample)
|
||||
|
||||
if background is not None:
|
||||
image = self.background(image)
|
||||
|
||||
image = self.colormode(image)
|
||||
|
||||
return image
|
@ -1,46 +0,0 @@
|
||||
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
|
||||
import errno
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseStorageBackend(ABC):
|
||||
def __init__(self, app=None):
|
||||
self.app = app
|
||||
|
||||
@abstractmethod
|
||||
def read(self, filepath, mode="rb", **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, filepath):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def save(self, filepath, data):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FilesystemStorageBackend(BaseStorageBackend):
|
||||
def read(self, filepath, mode="rb", **kwargs):
|
||||
with open(filepath, mode) as f: # pylint: disable=unspecified-encoding
|
||||
return f.read()
|
||||
|
||||
def exists(self, filepath):
|
||||
return os.path.exists(filepath)
|
||||
|
||||
def save(self, filepath, data):
|
||||
directory = os.path.dirname(filepath)
|
||||
|
||||
if not os.path.exists(directory):
|
||||
try:
|
||||
os.makedirs(directory)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
|
||||
if not os.path.isdir(directory):
|
||||
raise IOError("{} is not a directory".format(directory))
|
||||
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(data)
|
@ -1,65 +0,0 @@
|
||||
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
def generate_filename(directory: Path, original_filename, *options) -> str:
|
||||
text = str(directory.absolute()) + original_filename
|
||||
for v in options:
|
||||
text += "%s" % v
|
||||
md5_hash = hashlib.md5()
|
||||
md5_hash.update(text.encode("utf-8"))
|
||||
return md5_hash.hexdigest() + ".jpg"
|
||||
|
||||
|
||||
def parse_size(size):
|
||||
if isinstance(size, int):
|
||||
# If the size parameter is a single number, assume square aspect.
|
||||
return [size, size]
|
||||
|
||||
if isinstance(size, (tuple, list)):
|
||||
if len(size) == 1:
|
||||
# If single value tuple/list is provided, exand it to two elements
|
||||
return size + type(size)(size)
|
||||
return size
|
||||
|
||||
try:
|
||||
thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
|
||||
except ValueError:
|
||||
raise ValueError( # pylint: disable=raise-missing-from
|
||||
"Bad thumbnail size format. Valid format is INTxINT."
|
||||
)
|
||||
|
||||
if len(thumbnail_size) == 1:
|
||||
# If the size parameter only contains a single integer, assume square aspect.
|
||||
thumbnail_size.append(thumbnail_size[0])
|
||||
|
||||
return thumbnail_size
|
||||
|
||||
|
||||
def aspect_to_string(size):
|
||||
if isinstance(size, str):
|
||||
return size
|
||||
|
||||
return "x".join(map(str, size))
|
||||
|
||||
|
||||
IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
|
||||
|
||||
|
||||
def glob_img(p: Union[Path, str], recursive: bool = False):
|
||||
p = Path(p)
|
||||
if p.is_file() and p.suffix in IMG_SUFFIX:
|
||||
yield p
|
||||
else:
|
||||
if recursive:
|
||||
files = Path(p).glob("**/*.*")
|
||||
else:
|
||||
files = Path(p).glob("*.*")
|
||||
|
||||
for it in files:
|
||||
if it.suffix not in IMG_SUFFIX:
|
||||
continue
|
||||
yield it
|
@ -1,408 +0,0 @@
|
||||
import base64
|
||||
import imghdr
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
|
||||
from urllib.parse import urlparse
|
||||
import cv2
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopaint.const import MPS_UNSUPPORT_MODELS
|
||||
from loguru import logger
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
import hashlib
|
||||
|
||||
|
||||
def md5sum(filename):
|
||||
md5 = hashlib.md5()
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
|
||||
md5.update(chunk)
|
||||
return md5.hexdigest()
|
||||
|
||||
|
||||
def switch_mps_device(model_name, device):
|
||||
if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
|
||||
logger.info(f"{model_name} not support mps, switch to cpu")
|
||||
return torch.device("cpu")
|
||||
return device
|
||||
|
||||
|
||||
def get_cache_path_by_url(url):
|
||||
parts = urlparse(url)
|
||||
hub_dir = get_dir()
|
||||
model_dir = os.path.join(hub_dir, "checkpoints")
|
||||
if not os.path.isdir(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(model_dir, filename)
|
||||
return cached_file
|
||||
|
||||
|
||||
def download_model(url, model_md5: str = None):
|
||||
if os.path.exists(url):
|
||||
cached_file = url
|
||||
else:
|
||||
cached_file = get_cache_path_by_url(url)
|
||||
if not os.path.exists(cached_file):
|
||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = None
|
||||
download_url_to_file(url, cached_file, hash_prefix, progress=True)
|
||||
if model_md5:
|
||||
_md5 = md5sum(cached_file)
|
||||
if model_md5 == _md5:
|
||||
logger.info(f"Download model success, md5: {_md5}")
|
||||
else:
|
||||
try:
|
||||
os.remove(cached_file)
|
||||
logger.error(
|
||||
f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint."
|
||||
f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
|
||||
)
|
||||
except:
|
||||
logger.error(
|
||||
f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart iopaint."
|
||||
)
|
||||
exit(-1)
|
||||
|
||||
return cached_file
|
||||
|
||||
|
||||
def ceil_modulo(x, mod):
|
||||
if x % mod == 0:
|
||||
return x
|
||||
return (x // mod + 1) * mod
|
||||
|
||||
|
||||
def handle_error(model_path, model_md5, e):
|
||||
_md5 = md5sum(model_path)
|
||||
if _md5 != model_md5:
|
||||
try:
|
||||
os.remove(model_path)
|
||||
logger.error(
|
||||
f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart iopaint."
|
||||
f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
|
||||
)
|
||||
except:
|
||||
logger.error(
|
||||
f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart iopaint."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to load model {model_path},"
|
||||
f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
|
||||
)
|
||||
exit(-1)
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device, model_md5: str):
|
||||
if os.path.exists(url_or_path):
|
||||
model_path = url_or_path
|
||||
else:
|
||||
model_path = download_model(url_or_path, model_md5)
|
||||
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
try:
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
except Exception as e:
|
||||
handle_error(model_path, model_md5, e)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
|
||||
if os.path.exists(url_or_path):
|
||||
model_path = url_or_path
|
||||
else:
|
||||
model_path = download_model(url_or_path, model_md5)
|
||||
|
||||
try:
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
except Exception as e:
|
||||
handle_error(model_path, model_md5, e)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
||||
data = cv2.imencode(
|
||||
f".{ext}",
|
||||
image_numpy,
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
||||
)[1]
|
||||
image_bytes = data.tobytes()
|
||||
return image_bytes
|
||||
|
||||
|
||||
def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
|
||||
with io.BytesIO() as output:
|
||||
kwargs = {k: v for k, v in infos.items() if v is not None}
|
||||
if ext == "jpg":
|
||||
ext = "jpeg"
|
||||
if "png" == ext.lower() and "parameters" in kwargs:
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
pnginfo_data.add_text("parameters", kwargs["parameters"])
|
||||
kwargs["pnginfo"] = pnginfo_data
|
||||
|
||||
pil_img.save(output, format=ext, quality=quality, **kwargs)
|
||||
image_bytes = output.getvalue()
|
||||
return image_bytes
|
||||
|
||||
|
||||
def load_img(img_bytes, gray: bool = False, return_info: bool = False):
|
||||
alpha_channel = None
|
||||
image = Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
if return_info:
|
||||
infos = image.info
|
||||
|
||||
try:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except:
|
||||
pass
|
||||
|
||||
if gray:
|
||||
image = image.convert("L")
|
||||
np_img = np.array(image)
|
||||
else:
|
||||
if image.mode == "RGBA":
|
||||
np_img = np.array(image)
|
||||
alpha_channel = np_img[:, :, -1]
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
||||
else:
|
||||
image = image.convert("RGB")
|
||||
np_img = np.array(image)
|
||||
|
||||
if return_info:
|
||||
return np_img, alpha_channel, infos
|
||||
return np_img, alpha_channel
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
if len(np_img.shape) == 2:
|
||||
np_img = np_img[:, :, np.newaxis]
|
||||
np_img = np.transpose(np_img, (2, 0, 1))
|
||||
np_img = np_img.astype("float32") / 255
|
||||
return np_img
|
||||
|
||||
|
||||
def resize_max_size(
|
||||
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
|
||||
) -> np.ndarray:
|
||||
# Resize image's longer size to size_limit if longer size larger than size_limit
|
||||
h, w = np_img.shape[:2]
|
||||
if max(h, w) > size_limit:
|
||||
ratio = size_limit / max(h, w)
|
||||
new_w = int(w * ratio + 0.5)
|
||||
new_h = int(h * ratio + 0.5)
|
||||
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
|
||||
else:
|
||||
return np_img
|
||||
|
||||
|
||||
def pad_img_to_modulo(
|
||||
img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
img: [H, W, C]
|
||||
mod:
|
||||
square: 是否为正方形
|
||||
min_size:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if len(img.shape) == 2:
|
||||
img = img[:, :, np.newaxis]
|
||||
height, width = img.shape[:2]
|
||||
out_height = ceil_modulo(height, mod)
|
||||
out_width = ceil_modulo(width, mod)
|
||||
|
||||
if min_size is not None:
|
||||
assert min_size % mod == 0
|
||||
out_width = max(min_size, out_width)
|
||||
out_height = max(min_size, out_height)
|
||||
|
||||
if square:
|
||||
max_size = max(out_height, out_width)
|
||||
out_height = max_size
|
||||
out_width = max_size
|
||||
|
||||
return np.pad(
|
||||
img,
|
||||
((0, out_height - height), (0, out_width - width), (0, 0)),
|
||||
mode="symmetric",
|
||||
)
|
||||
|
||||
|
||||
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
||||
"""
|
||||
Args:
|
||||
mask: (h, w, 1) 0~255
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
height, width = mask.shape[:2]
|
||||
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
boxes = []
|
||||
for cnt in contours:
|
||||
x, y, w, h = cv2.boundingRect(cnt)
|
||||
box = np.array([x, y, x + w, y + h]).astype(int)
|
||||
|
||||
box[::2] = np.clip(box[::2], 0, width)
|
||||
box[1::2] = np.clip(box[1::2], 0, height)
|
||||
boxes.append(box)
|
||||
|
||||
return boxes
|
||||
|
||||
|
||||
def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
|
||||
"""
|
||||
Args:
|
||||
mask: (h, w) 0~255
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
max_area = 0
|
||||
max_index = -1
|
||||
for i, cnt in enumerate(contours):
|
||||
area = cv2.contourArea(cnt)
|
||||
if area > max_area:
|
||||
max_area = area
|
||||
max_index = i
|
||||
|
||||
if max_index != -1:
|
||||
new_mask = np.zeros_like(mask)
|
||||
return cv2.drawContours(new_mask, contours, max_index, 255, -1)
|
||||
else:
|
||||
return mask
|
||||
|
||||
|
||||
def is_mac():
|
||||
return sys.platform == "darwin"
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
w = imghdr.what("", img_bytes)
|
||||
if w is None:
|
||||
w = "jpeg"
|
||||
return w
|
||||
|
||||
|
||||
def decode_base64_to_image(
|
||||
encoding: str, gray=False
|
||||
) -> Tuple[np.array, Optional[np.array], Dict]:
|
||||
if encoding.startswith("data:image/") or encoding.startswith(
|
||||
"data:application/octet-stream;base64,"
|
||||
):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
image = Image.open(io.BytesIO(base64.b64decode(encoding)))
|
||||
|
||||
alpha_channel = None
|
||||
try:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except:
|
||||
pass
|
||||
# exif_transpose will remove exif rotate info,we must call image.info after exif_transpose
|
||||
infos = image.info
|
||||
|
||||
if gray:
|
||||
image = image.convert("L")
|
||||
np_img = np.array(image)
|
||||
else:
|
||||
if image.mode == "RGBA":
|
||||
np_img = np.array(image)
|
||||
alpha_channel = np_img[:, :, -1]
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
||||
else:
|
||||
image = image.convert("RGB")
|
||||
np_img = np.array(image)
|
||||
|
||||
return np_img, alpha_channel, infos
|
||||
|
||||
|
||||
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
|
||||
img_bytes = pil_to_bytes(
|
||||
image,
|
||||
"png",
|
||||
quality=quality,
|
||||
infos=infos,
|
||||
)
|
||||
return base64.b64encode(img_bytes)
|
||||
|
||||
|
||||
def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
|
||||
alpha_channel = cv2.resize(
|
||||
alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
|
||||
)
|
||||
rgb_np_img = np.concatenate(
|
||||
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||
)
|
||||
return rgb_np_img
|
||||
|
||||
|
||||
def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
|
||||
# fronted brush color "ffcc00bb"
|
||||
# kernel_size = kernel_size*2+1
|
||||
mask[mask >= 127] = 255
|
||||
mask[mask < 127] = 0
|
||||
|
||||
if operate == "reverse":
|
||||
mask = 255 - mask
|
||||
else:
|
||||
kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
|
||||
)
|
||||
if operate == "expand":
|
||||
mask = cv2.dilate(
|
||||
mask,
|
||||
kernel,
|
||||
iterations=1,
|
||||
)
|
||||
else:
|
||||
mask = cv2.erode(
|
||||
mask,
|
||||
kernel,
|
||||
iterations=1,
|
||||
)
|
||||
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
||||
res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
|
||||
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|
||||
return res_mask
|
||||
|
||||
|
||||
def gen_frontend_mask(bgr_or_gray_mask):
|
||||
if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
|
||||
bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# fronted brush color "ffcc00bb"
|
||||
# TODO: how to set kernel size?
|
||||
kernel_size = 9
|
||||
bgr_or_gray_mask = cv2.dilate(
|
||||
bgr_or_gray_mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8),
|
||||
iterations=1,
|
||||
)
|
||||
res_mask = np.zeros(
|
||||
(bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
|
||||
)
|
||||
res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
|
||||
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|
||||
return res_mask
|
@ -1,10 +0,0 @@
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def install(package):
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
||||
|
||||
|
||||
def install_plugins_package():
|
||||
install("rembg")
|
@ -1,37 +0,0 @@
|
||||
from .anytext.anytext_model import AnyText
|
||||
from .controlnet import ControlNet
|
||||
from .fcf import FcF
|
||||
from .instruct_pix2pix import InstructPix2Pix
|
||||
from .kandinsky import Kandinsky22
|
||||
from .lama import LaMa
|
||||
from .ldm import LDM
|
||||
from .manga import Manga
|
||||
from .mat import MAT
|
||||
from .mi_gan import MIGAN
|
||||
from .opencv2 import OpenCV2
|
||||
from .paint_by_example import PaintByExample
|
||||
from .power_paint.power_paint import PowerPaint
|
||||
from .sd import SD15, SD2, Anything4, RealisticVision14, SD
|
||||
from .sdxl import SDXL
|
||||
from .zits import ZITS
|
||||
|
||||
models = {
|
||||
LaMa.name: LaMa,
|
||||
LDM.name: LDM,
|
||||
ZITS.name: ZITS,
|
||||
MAT.name: MAT,
|
||||
FcF.name: FcF,
|
||||
OpenCV2.name: OpenCV2,
|
||||
Manga.name: Manga,
|
||||
MIGAN.name: MIGAN,
|
||||
SD15.name: SD15,
|
||||
Anything4.name: Anything4,
|
||||
RealisticVision14.name: RealisticVision14,
|
||||
SD2.name: SD2,
|
||||
PaintByExample.name: PaintByExample,
|
||||
InstructPix2Pix.name: InstructPix2Pix,
|
||||
Kandinsky22.name: Kandinsky22,
|
||||
SDXL.name: SDXL,
|
||||
PowerPaint.name: PowerPaint,
|
||||
AnyText.name: AnyText,
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from iopaint.const import ANYTEXT_NAME
|
||||
from iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
|
||||
from iopaint.model.base import DiffusionInpaintModel
|
||||
from iopaint.model.utils import get_torch_dtype, is_local_files_only
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
|
||||
class AnyText(DiffusionInpaintModel):
|
||||
name = ANYTEXT_NAME
|
||||
pad_mod = 64
|
||||
is_erase_model = False
|
||||
|
||||
@staticmethod
|
||||
def download(local_files_only=False):
|
||||
hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="model_index.json",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
ckpt_path = hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="pytorch_model.fp16.safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
font_path = hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="SourceHanSansSC-Medium.otf",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return ckpt_path, font_path
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
local_files_only = is_local_files_only(**kwargs)
|
||||
ckpt_path, font_path = self.download(local_files_only)
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
self.model = AnyTextPipeline(
|
||||
ckpt_path=ckpt_path,
|
||||
font_path=font_path,
|
||||
device=device,
|
||||
use_fp16=torch_dtype == torch.float16,
|
||||
)
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to inpainting
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
height, width = image.shape[:2]
|
||||
mask = mask.astype("float32") / 255.0
|
||||
masked_image = image * (1 - mask)
|
||||
|
||||
# list of rgb ndarray
|
||||
results, rtn_code, rtn_warning = self.model(
|
||||
image=image,
|
||||
masked_image=masked_image,
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
num_inference_steps=config.sd_steps,
|
||||
strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
height=height,
|
||||
width=width,
|
||||
seed=config.sd_seed,
|
||||
sort_priority="y",
|
||||
callback=self.callback
|
||||
)
|
||||
inpainted_rgb_image = results[0][..., ::-1]
|
||||
return inpainted_rgb_image
|
@ -1,403 +0,0 @@
|
||||
"""
|
||||
AnyText: Multilingual Visual Text Generation And Editing
|
||||
Paper: https://arxiv.org/abs/2311.03054
|
||||
Code: https://github.com/tyxsspa/AnyText
|
||||
Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from iopaint.model.utils import set_seed
|
||||
from safetensors.torch import load_file
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
import torch
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import einops
|
||||
from PIL import ImageFont
|
||||
from iopaint.model.anytext.cldm.model import create_model, load_state_dict
|
||||
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
|
||||
from iopaint.model.anytext.utils import (
|
||||
check_channels,
|
||||
draw_glyph,
|
||||
draw_glyph2,
|
||||
)
|
||||
|
||||
|
||||
BBOX_MAX_NUM = 8
|
||||
PLACE_HOLDER = "*"
|
||||
max_chars = 20
|
||||
|
||||
ANYTEXT_CFG = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
|
||||
)
|
||||
|
||||
|
||||
def check_limits(tensor):
|
||||
float16_min = torch.finfo(torch.float16).min
|
||||
float16_max = torch.finfo(torch.float16).max
|
||||
|
||||
# 检查张量中是否有值小于float16的最小值或大于float16的最大值
|
||||
is_below_min = (tensor < float16_min).any()
|
||||
is_above_max = (tensor > float16_max).any()
|
||||
|
||||
return is_below_min or is_above_max
|
||||
|
||||
|
||||
class AnyTextPipeline:
|
||||
def __init__(self, ckpt_path, font_path, device, use_fp16=True):
|
||||
self.cfg_path = ANYTEXT_CFG
|
||||
self.font_path = font_path
|
||||
self.use_fp16 = use_fp16
|
||||
self.device = device
|
||||
|
||||
self.font = ImageFont.truetype(font_path, size=60)
|
||||
self.model = create_model(
|
||||
self.cfg_path,
|
||||
device=self.device,
|
||||
use_fp16=self.use_fp16,
|
||||
)
|
||||
if self.use_fp16:
|
||||
self.model = self.model.half()
|
||||
if Path(ckpt_path).suffix == ".safetensors":
|
||||
state_dict = load_file(ckpt_path, device="cpu")
|
||||
else:
|
||||
state_dict = load_state_dict(ckpt_path, location="cpu")
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model = self.model.eval().to(self.device)
|
||||
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image: np.ndarray,
|
||||
masked_image: np.ndarray,
|
||||
num_inference_steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
height: int,
|
||||
width: int,
|
||||
seed: int,
|
||||
sort_priority: str = "y",
|
||||
callback=None,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
prompt:
|
||||
negative_prompt:
|
||||
image:
|
||||
masked_image:
|
||||
num_inference_steps:
|
||||
strength:
|
||||
guidance_scale:
|
||||
height:
|
||||
width:
|
||||
seed:
|
||||
sort_priority: x: left-right, y: top-down
|
||||
|
||||
Returns:
|
||||
result: list of images in numpy.ndarray format
|
||||
rst_code: 0: normal -1: error 1:warning
|
||||
rst_info: string of error or warning
|
||||
|
||||
"""
|
||||
set_seed(seed)
|
||||
str_warning = ""
|
||||
|
||||
mode = "text-editing"
|
||||
revise_pos = False
|
||||
img_count = 1
|
||||
ddim_steps = num_inference_steps
|
||||
w = width
|
||||
h = height
|
||||
strength = strength
|
||||
cfg_scale = guidance_scale
|
||||
eta = 0.0
|
||||
|
||||
prompt, texts = self.modify_prompt(prompt)
|
||||
if prompt is None and texts is None:
|
||||
return (
|
||||
None,
|
||||
-1,
|
||||
"You have input Chinese prompt but the translator is not loaded!",
|
||||
"",
|
||||
)
|
||||
n_lines = len(texts)
|
||||
if mode in ["text-generation", "gen"]:
|
||||
edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
|
||||
elif mode in ["text-editing", "edit"]:
|
||||
if masked_image is None or image is None:
|
||||
return (
|
||||
None,
|
||||
-1,
|
||||
"Reference image and position image are needed for text editing!",
|
||||
"",
|
||||
)
|
||||
if isinstance(image, str):
|
||||
image = cv2.imread(image)[..., ::-1]
|
||||
assert image is not None, f"Can't read ori_image image from{image}!"
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
else:
|
||||
assert isinstance(
|
||||
image, np.ndarray
|
||||
), f"Unknown format of ori_image: {type(image)}"
|
||||
edit_image = image.clip(1, 255) # for mask reason
|
||||
edit_image = check_channels(edit_image)
|
||||
# edit_image = resize_image(
|
||||
# edit_image, max_length=768
|
||||
# ) # make w h multiple of 64, resize if w or h > max_length
|
||||
h, w = edit_image.shape[:2] # change h, w by input ref_img
|
||||
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
|
||||
if masked_image is None:
|
||||
pos_imgs = np.zeros((w, h, 1))
|
||||
if isinstance(masked_image, str):
|
||||
masked_image = cv2.imread(masked_image)[..., ::-1]
|
||||
assert (
|
||||
masked_image is not None
|
||||
), f"Can't read draw_pos image from{masked_image}!"
|
||||
pos_imgs = 255 - masked_image
|
||||
elif isinstance(masked_image, torch.Tensor):
|
||||
pos_imgs = masked_image.cpu().numpy()
|
||||
else:
|
||||
assert isinstance(
|
||||
masked_image, np.ndarray
|
||||
), f"Unknown format of draw_pos: {type(masked_image)}"
|
||||
pos_imgs = 255 - masked_image
|
||||
pos_imgs = pos_imgs[..., 0:1]
|
||||
pos_imgs = cv2.convertScaleAbs(pos_imgs)
|
||||
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
|
||||
# seprate pos_imgs
|
||||
pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
|
||||
if len(pos_imgs) == 0:
|
||||
pos_imgs = [np.zeros((h, w, 1))]
|
||||
if len(pos_imgs) < n_lines:
|
||||
if n_lines == 1 and texts[0] == " ":
|
||||
pass # text-to-image without text
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
|
||||
)
|
||||
elif len(pos_imgs) > n_lines:
|
||||
str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
|
||||
# get pre_pos, poly_list, hint that needed for anytext
|
||||
pre_pos = []
|
||||
poly_list = []
|
||||
for input_pos in pos_imgs:
|
||||
if input_pos.mean() != 0:
|
||||
input_pos = (
|
||||
input_pos[..., np.newaxis]
|
||||
if len(input_pos.shape) == 2
|
||||
else input_pos
|
||||
)
|
||||
poly, pos_img = self.find_polygon(input_pos)
|
||||
pre_pos += [pos_img / 255.0]
|
||||
poly_list += [poly]
|
||||
else:
|
||||
pre_pos += [np.zeros((h, w, 1))]
|
||||
poly_list += [None]
|
||||
np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
|
||||
# prepare info dict
|
||||
info = {}
|
||||
info["glyphs"] = []
|
||||
info["gly_line"] = []
|
||||
info["positions"] = []
|
||||
info["n_lines"] = [len(texts)] * img_count
|
||||
gly_pos_imgs = []
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if len(text) > max_chars:
|
||||
str_warning = (
|
||||
f'"{text}" length > max_chars: {max_chars}, will be cut off...'
|
||||
)
|
||||
text = text[:max_chars]
|
||||
gly_scale = 2
|
||||
if pre_pos[i].mean() != 0:
|
||||
gly_line = draw_glyph(self.font, text)
|
||||
glyphs = draw_glyph2(
|
||||
self.font,
|
||||
text,
|
||||
poly_list[i],
|
||||
scale=gly_scale,
|
||||
width=w,
|
||||
height=h,
|
||||
add_space=False,
|
||||
)
|
||||
gly_pos_img = cv2.drawContours(
|
||||
glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
|
||||
)
|
||||
if revise_pos:
|
||||
resize_gly = cv2.resize(
|
||||
glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
|
||||
)
|
||||
new_pos = cv2.morphologyEx(
|
||||
(resize_gly * 255).astype(np.uint8),
|
||||
cv2.MORPH_CLOSE,
|
||||
kernel=np.ones(
|
||||
(resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
iterations=1,
|
||||
)
|
||||
new_pos = (
|
||||
new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
|
||||
)
|
||||
contours, _ = cv2.findContours(
|
||||
new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
|
||||
)
|
||||
if len(contours) != 1:
|
||||
str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
|
||||
else:
|
||||
rect = cv2.minAreaRect(contours[0])
|
||||
poly = np.int0(cv2.boxPoints(rect))
|
||||
pre_pos[i] = (
|
||||
cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
|
||||
)
|
||||
gly_pos_img = cv2.drawContours(
|
||||
glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
|
||||
)
|
||||
gly_pos_imgs += [gly_pos_img] # for show
|
||||
else:
|
||||
glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
|
||||
gly_line = np.zeros((80, 512, 1))
|
||||
gly_pos_imgs += [
|
||||
np.zeros((h * gly_scale, w * gly_scale, 1))
|
||||
] # for show
|
||||
pos = pre_pos[i]
|
||||
info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
|
||||
info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
|
||||
info["positions"] += [self.arr2tensor(pos, img_count)]
|
||||
# get masked_x
|
||||
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
||||
masked_img = np.transpose(masked_img, (2, 0, 1))
|
||||
masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
|
||||
if self.use_fp16:
|
||||
masked_img = masked_img.half()
|
||||
encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
|
||||
masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
|
||||
if self.use_fp16:
|
||||
masked_x = masked_x.half()
|
||||
info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
|
||||
|
||||
hint = self.arr2tensor(np_hint, img_count)
|
||||
cond = self.model.get_learned_conditioning(
|
||||
dict(
|
||||
c_concat=[hint],
|
||||
c_crossattn=[[prompt] * img_count],
|
||||
text_info=info,
|
||||
)
|
||||
)
|
||||
un_cond = self.model.get_learned_conditioning(
|
||||
dict(
|
||||
c_concat=[hint],
|
||||
c_crossattn=[[negative_prompt] * img_count],
|
||||
text_info=info,
|
||||
)
|
||||
)
|
||||
shape = (4, h // 8, w // 8)
|
||||
self.model.control_scales = [strength] * 13
|
||||
samples, intermediates = self.ddim_sampler.sample(
|
||||
ddim_steps,
|
||||
img_count,
|
||||
shape,
|
||||
cond,
|
||||
verbose=False,
|
||||
eta=eta,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=un_cond,
|
||||
callback=callback
|
||||
)
|
||||
if self.use_fp16:
|
||||
samples = samples.half()
|
||||
x_samples = self.model.decode_first_stage(samples)
|
||||
x_samples = (
|
||||
(einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
|
||||
.cpu()
|
||||
.numpy()
|
||||
.clip(0, 255)
|
||||
.astype(np.uint8)
|
||||
)
|
||||
results = [x_samples[i] for i in range(img_count)]
|
||||
# if (
|
||||
# mode == "edit" and False
|
||||
# ): # replace backgound in text editing but not ideal yet
|
||||
# results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
|
||||
# results = [r.clip(0, 255).astype(np.uint8) for r in results]
|
||||
# if len(gly_pos_imgs) > 0 and show_debug:
|
||||
# glyph_bs = np.stack(gly_pos_imgs, axis=2)
|
||||
# glyph_img = np.sum(glyph_bs, axis=2) * 255
|
||||
# glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
|
||||
# results += [np.repeat(glyph_img, 3, axis=2)]
|
||||
rst_code = 1 if str_warning else 0
|
||||
return results, rst_code, str_warning
|
||||
|
||||
def modify_prompt(self, prompt):
|
||||
prompt = prompt.replace("“", '"')
|
||||
prompt = prompt.replace("”", '"')
|
||||
p = '"(.*?)"'
|
||||
strs = re.findall(p, prompt)
|
||||
if len(strs) == 0:
|
||||
strs = [" "]
|
||||
else:
|
||||
for s in strs:
|
||||
prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
|
||||
# if self.is_chinese(prompt):
|
||||
# if self.trans_pipe is None:
|
||||
# return None, None
|
||||
# old_prompt = prompt
|
||||
# prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
|
||||
# print(f"Translate: {old_prompt} --> {prompt}")
|
||||
return prompt, strs
|
||||
|
||||
# def is_chinese(self, text):
|
||||
# text = checker._clean_text(text)
|
||||
# for char in text:
|
||||
# cp = ord(char)
|
||||
# if checker._is_chinese_char(cp):
|
||||
# return True
|
||||
# return False
|
||||
|
||||
def separate_pos_imgs(self, img, sort_priority, gap=102):
|
||||
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
|
||||
components = []
|
||||
for label in range(1, num_labels):
|
||||
component = np.zeros_like(img)
|
||||
component[labels == label] = 255
|
||||
components.append((component, centroids[label]))
|
||||
if sort_priority == "y":
|
||||
fir, sec = 1, 0 # top-down first
|
||||
elif sort_priority == "x":
|
||||
fir, sec = 0, 1 # left-right first
|
||||
components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
|
||||
sorted_components = [c[0] for c in components]
|
||||
return sorted_components
|
||||
|
||||
def find_polygon(self, image, min_rect=False):
|
||||
contours, hierarchy = cv2.findContours(
|
||||
image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
|
||||
)
|
||||
max_contour = max(contours, key=cv2.contourArea) # get contour with max area
|
||||
if min_rect:
|
||||
# get minimum enclosing rectangle
|
||||
rect = cv2.minAreaRect(max_contour)
|
||||
poly = np.int0(cv2.boxPoints(rect))
|
||||
else:
|
||||
# get approximate polygon
|
||||
epsilon = 0.01 * cv2.arcLength(max_contour, True)
|
||||
poly = cv2.approxPolyDP(max_contour, epsilon, True)
|
||||
n, _, xy = poly.shape
|
||||
poly = poly.reshape(n, xy)
|
||||
cv2.drawContours(image, [poly], -1, 255, -1)
|
||||
return poly, image
|
||||
|
||||
def arr2tensor(self, arr, bs):
|
||||
arr = np.transpose(arr, (2, 0, 1))
|
||||
_arr = torch.from_numpy(arr.copy()).float().to(self.device)
|
||||
if self.use_fp16:
|
||||
_arr = _arr.half()
|
||||
_arr = torch.stack([_arr for _ in range(bs)], dim=0)
|
||||
return _arr
|
@ -1,99 +0,0 @@
|
||||
model:
|
||||
target: iopaint.model.anytext.cldm.cldm.ControlLDM
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "img"
|
||||
cond_stage_key: "caption"
|
||||
control_key: "hint"
|
||||
glyph_key: "glyphs"
|
||||
position_key: "positions"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: true # need be true when embedding_manager is valid
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
only_mid_control: False
|
||||
loss_alpha: 0 # perceptual loss, 0.003
|
||||
loss_beta: 0 # ctc loss
|
||||
latin_weight: 1.0 # latin text line may need smaller weigth
|
||||
with_step_weight: true
|
||||
use_vae_upsample: true
|
||||
embedding_manager_config:
|
||||
target: iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager
|
||||
params:
|
||||
valid: true # v6
|
||||
emb_type: ocr # ocr, vit, conv
|
||||
glyph_channels: 1
|
||||
position_channels: 1
|
||||
add_pos: false
|
||||
placeholder_string: '*'
|
||||
|
||||
control_stage_config:
|
||||
target: iopaint.model.anytext.cldm.cldm.ControlNet
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
model_channels: 320
|
||||
glyph_channels: 1
|
||||
position_channels: 1
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
unet_config:
|
||||
target: iopaint.model.anytext.cldm.cldm.ControlledUnetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
|
||||
params:
|
||||
version: openai/clip-vit-large-patch14
|
||||
use_vision: false # v6
|
@ -1,630 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import copy
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
|
||||
conv_nd,
|
||||
linear,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from iopaint.model.anytext.ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
|
||||
from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from .recognizer import TextRecognizer, create_predictor
|
||||
|
||||
CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
|
||||
hs = []
|
||||
with torch.no_grad():
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
if self.use_fp16:
|
||||
t_emb = t_emb.half()
|
||||
emb = self.time_embed(t_emb)
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
|
||||
if control is not None:
|
||||
h += control.pop()
|
||||
|
||||
for i, module in enumerate(self.output_blocks):
|
||||
if only_mid_control or control is None:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
else:
|
||||
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class ControlNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
glyph_channels,
|
||||
position_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
self.dims = dims
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_fp16 = use_fp16
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
||||
|
||||
self.glyph_block = TimestepEmbedSequential(
|
||||
conv_nd(dims, glyph_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_block = TimestepEmbedSequential(
|
||||
conv_nd(dims, position_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 64, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
|
||||
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self.middle_block_out = self.make_zero_conv(ch)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels):
|
||||
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
|
||||
def forward(self, x, hint, text_info, timesteps, context, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
if self.use_fp16:
|
||||
t_emb = t_emb.half()
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
# guided_hint from text_info
|
||||
B, C, H, W = x.shape
|
||||
glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
|
||||
positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
|
||||
enc_glyph = self.glyph_block(glyphs, emb, context)
|
||||
enc_pos = self.position_block(positions, emb, context)
|
||||
guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
|
||||
|
||||
outs = []
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
h += guided_hint
|
||||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
outs.append(zero_conv(h, emb, context))
|
||||
|
||||
h = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
class ControlLDM(LatentDiffusion):
|
||||
|
||||
def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
|
||||
self.use_fp16 = kwargs.pop('use_fp16', False)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.control_model = instantiate_from_config(control_stage_config)
|
||||
self.control_key = control_key
|
||||
self.glyph_key = glyph_key
|
||||
self.position_key = position_key
|
||||
self.only_mid_control = only_mid_control
|
||||
self.control_scales = [1.0] * 13
|
||||
self.loss_alpha = loss_alpha
|
||||
self.loss_beta = loss_beta
|
||||
self.with_step_weight = with_step_weight
|
||||
self.use_vae_upsample = use_vae_upsample
|
||||
self.latin_weight = latin_weight
|
||||
|
||||
if embedding_manager_config is not None and embedding_manager_config.params.valid:
|
||||
self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
|
||||
for param in self.embedding_manager.embedding_parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
self.embedding_manager = None
|
||||
if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
|
||||
if embedding_manager_config.params.emb_type == 'ocr':
|
||||
self.text_predictor = create_predictor().eval()
|
||||
args = edict()
|
||||
args.rec_image_shape = "3, 48, 320"
|
||||
args.rec_batch_num = 6
|
||||
args.rec_char_dict_path = str(CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt")
|
||||
args.use_fp16 = self.use_fp16
|
||||
self.cn_recognizer = TextRecognizer(args, self.text_predictor)
|
||||
for param in self.text_predictor.parameters():
|
||||
param.requires_grad = False
|
||||
if self.embedding_manager:
|
||||
self.embedding_manager.recog = self.cn_recognizer
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
||||
if self.embedding_manager is None: # fill in full caption
|
||||
self.fill_caption(batch)
|
||||
x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
|
||||
control = batch[self.control_key] # for log_images and loss_alpha, not real control
|
||||
if bs is not None:
|
||||
control = control[:bs]
|
||||
control = control.to(self.device)
|
||||
control = einops.rearrange(control, 'b h w c -> b c h w')
|
||||
control = control.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
inv_mask = batch['inv_mask']
|
||||
if bs is not None:
|
||||
inv_mask = inv_mask[:bs]
|
||||
inv_mask = inv_mask.to(self.device)
|
||||
inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
|
||||
inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
glyphs = batch[self.glyph_key]
|
||||
gly_line = batch['gly_line']
|
||||
positions = batch[self.position_key]
|
||||
n_lines = batch['n_lines']
|
||||
language = batch['language']
|
||||
texts = batch['texts']
|
||||
assert len(glyphs) == len(positions)
|
||||
for i in range(len(glyphs)):
|
||||
if bs is not None:
|
||||
glyphs[i] = glyphs[i][:bs]
|
||||
gly_line[i] = gly_line[i][:bs]
|
||||
positions[i] = positions[i][:bs]
|
||||
n_lines = n_lines[:bs]
|
||||
glyphs[i] = glyphs[i].to(self.device)
|
||||
gly_line[i] = gly_line[i].to(self.device)
|
||||
positions[i] = positions[i].to(self.device)
|
||||
glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
|
||||
gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
|
||||
positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
|
||||
glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
|
||||
gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
|
||||
positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
|
||||
info = {}
|
||||
info['glyphs'] = glyphs
|
||||
info['positions'] = positions
|
||||
info['n_lines'] = n_lines
|
||||
info['language'] = language
|
||||
info['texts'] = texts
|
||||
info['img'] = batch['img'] # nhwc, (-1,1)
|
||||
info['masked_x'] = mx
|
||||
info['gly_line'] = gly_line
|
||||
info['inv_mask'] = inv_mask
|
||||
return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
|
||||
|
||||
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
||||
assert isinstance(cond, dict)
|
||||
diffusion_model = self.model.diffusion_model
|
||||
_cond = torch.cat(cond['c_crossattn'], 1)
|
||||
_hint = torch.cat(cond['c_concat'], 1)
|
||||
if self.use_fp16:
|
||||
x_noisy = x_noisy.half()
|
||||
control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
|
||||
control = [c * scale for c, scale in zip(control, self.control_scales)]
|
||||
eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
|
||||
|
||||
return eps
|
||||
|
||||
def instantiate_embedding_manager(self, config, embedder):
|
||||
model = instantiate_from_config(config, embedder=embedder)
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, N):
|
||||
return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
|
||||
|
||||
def get_learned_conditioning(self, c):
|
||||
if self.cond_stage_forward is None:
|
||||
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
|
||||
if self.embedding_manager is not None and c['text_info'] is not None:
|
||||
self.embedding_manager.encode_text(c['text_info'])
|
||||
if isinstance(c, dict):
|
||||
cond_txt = c['c_crossattn'][0]
|
||||
else:
|
||||
cond_txt = c
|
||||
if self.embedding_manager is not None:
|
||||
cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
|
||||
else:
|
||||
cond_txt = self.cond_stage_model.encode(cond_txt)
|
||||
if isinstance(c, dict):
|
||||
c['c_crossattn'][0] = cond_txt
|
||||
else:
|
||||
c = cond_txt
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
else:
|
||||
c = self.cond_stage_model(c)
|
||||
else:
|
||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||
return c
|
||||
|
||||
def fill_caption(self, batch, place_holder='*'):
|
||||
bs = len(batch['n_lines'])
|
||||
cond_list = copy.deepcopy(batch[self.cond_stage_key])
|
||||
for i in range(bs):
|
||||
n_lines = batch['n_lines'][i]
|
||||
if n_lines == 0:
|
||||
continue
|
||||
cur_cap = cond_list[i]
|
||||
for j in range(n_lines):
|
||||
r_txt = batch['texts'][j][i]
|
||||
cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
|
||||
cond_list[i] = cur_cap
|
||||
batch[self.cond_stage_key] = cond_list
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
||||
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
||||
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
|
||||
use_ema_scope=True,
|
||||
**kwargs):
|
||||
use_ddim = ddim_steps is not None
|
||||
|
||||
log = dict()
|
||||
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
||||
if self.cond_stage_trainable:
|
||||
with torch.no_grad():
|
||||
c = self.get_learned_conditioning(c)
|
||||
c_crossattn = c["c_crossattn"][0][:N]
|
||||
c_cat = c["c_concat"][0][:N]
|
||||
text_info = c["text_info"]
|
||||
text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
|
||||
text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
|
||||
text_info['positions'] = [i[:N] for i in text_info['positions']]
|
||||
text_info['n_lines'] = text_info['n_lines'][:N]
|
||||
text_info['masked_x'] = text_info['masked_x'][:N]
|
||||
text_info['img'] = text_info['img'][:N]
|
||||
|
||||
N = min(z.shape[0], N)
|
||||
n_row = min(z.shape[0], n_row)
|
||||
log["reconstruction"] = self.decode_first_stage(z)
|
||||
log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
|
||||
log["control"] = c_cat * 2.0 - 1.0
|
||||
log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed
|
||||
# get glyph
|
||||
glyph_bs = torch.stack(text_info['glyphs'])
|
||||
glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
|
||||
log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
|
||||
# fill caption
|
||||
if not self.embedding_manager:
|
||||
self.fill_caption(batch)
|
||||
captions = batch[self.cond_stage_key]
|
||||
log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
|
||||
|
||||
if plot_diffusion_rows:
|
||||
# get diffusion row
|
||||
diffusion_row = list()
|
||||
z_start = z[:n_row]
|
||||
for t in range(self.num_timesteps):
|
||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
||||
t = t.to(self.device).long()
|
||||
noise = torch.randn_like(z_start)
|
||||
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
||||
diffusion_row.append(self.decode_first_stage(z_noisy))
|
||||
|
||||
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
||||
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
||||
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
||||
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
||||
log["diffusion_row"] = diffusion_grid
|
||||
|
||||
if sample:
|
||||
# get denoise row
|
||||
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
|
||||
batch_size=N, ddim=use_ddim,
|
||||
ddim_steps=ddim_steps, eta=ddim_eta)
|
||||
x_samples = self.decode_first_stage(samples)
|
||||
log["samples"] = x_samples
|
||||
if plot_denoise_rows:
|
||||
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
||||
log["denoise_row"] = denoise_grid
|
||||
|
||||
if unconditional_guidance_scale > 1.0:
|
||||
uc_cross = self.get_unconditional_conditioning(N)
|
||||
uc_cat = c_cat # torch.zeros_like(c_cat)
|
||||
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
|
||||
samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
|
||||
batch_size=N, ddim=use_ddim,
|
||||
ddim_steps=ddim_steps, eta=ddim_eta,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc_full,
|
||||
)
|
||||
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
||||
pred_x0 = False # wether log pred_x0
|
||||
if pred_x0:
|
||||
for idx in range(len(tmps['pred_x0'])):
|
||||
pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
|
||||
log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
|
||||
|
||||
return log
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
||||
ddim_sampler = DDIMSampler(self)
|
||||
b, c, h, w = cond["c_concat"][0].shape
|
||||
shape = (self.channels, h // 8, w // 8)
|
||||
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
|
||||
return samples, intermediates
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
params = list(self.control_model.parameters())
|
||||
if self.embedding_manager:
|
||||
params += list(self.embedding_manager.embedding_parameters())
|
||||
if not self.sd_locked:
|
||||
# params += list(self.model.diffusion_model.input_blocks.parameters())
|
||||
# params += list(self.model.diffusion_model.middle_block.parameters())
|
||||
params += list(self.model.diffusion_model.output_blocks.parameters())
|
||||
params += list(self.model.diffusion_model.out.parameters())
|
||||
if self.unlockKV:
|
||||
nCount = 0
|
||||
for name, param in self.model.diffusion_model.named_parameters():
|
||||
if 'attn2.to_k' in name or 'attn2.to_v' in name:
|
||||
params += [param]
|
||||
nCount += 1
|
||||
print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to potimizers!!!')
|
||||
|
||||
opt = torch.optim.AdamW(params, lr=lr)
|
||||
return opt
|
||||
|
||||
def low_vram_shift(self, is_diffusing):
|
||||
if is_diffusing:
|
||||
self.model = self.model.cuda()
|
||||
self.control_model = self.control_model.cuda()
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
self.cond_stage_model = self.cond_stage_model.cpu()
|
||||
else:
|
||||
self.model = self.model.cpu()
|
||||
self.control_model = self.control_model.cpu()
|
||||
self.first_stage_model = self.first_stage_model.cuda()
|
||||
self.cond_stage_model = self.cond_stage_model.cuda()
|
@ -1,486 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, device, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.device(self.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(
|
||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||
):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(self.model.betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f"Data shape for DDIM sampling is {size}, eta {eta}")
|
||||
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
||||
time_range = (
|
||||
reversed(range(0, timesteps))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(None, i, None, None)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates["x_inter"].append(img)
|
||||
intermediates["pred_x0"].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
model_t = self.model.apply_model(x, t, c)
|
||||
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (
|
||||
model_t - model_uncond
|
||||
)
|
||||
|
||||
if self.model.parameterization == "v":
|
||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps", "not implemented"
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(
|
||||
self,
|
||||
x0,
|
||||
c,
|
||||
t_enc,
|
||||
use_original_steps=False,
|
||||
return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
callback=None,
|
||||
):
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
num_reference_steps = timesteps.shape[0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc="Encoding Image"):
|
||||
t = torch.full(
|
||||
(x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
|
||||
)
|
||||
if unconditional_guidance_scale == 1.0:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(
|
||||
torch.cat((x_next, x_next)),
|
||||
torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c)),
|
||||
),
|
||||
2,
|
||||
)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (
|
||||
noise_pred - e_t_uncond
|
||||
)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
weighted_noise_pred = (
|
||||
alphas_next[i].sqrt()
|
||||
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
|
||||
* noise_pred
|
||||
)
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if (
|
||||
return_intermediates
|
||||
and i % (num_steps // return_intermediates) == 0
|
||||
and i < num_steps - 1
|
||||
):
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback:
|
||||
callback(i)
|
||||
|
||||
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({"intermediates": intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
callback=None,
|
||||
):
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
|
||||
)
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
if callback:
|
||||
callback(i)
|
||||
return x_dec
|
@ -1,165 +0,0 @@
|
||||
'''
|
||||
Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear
|
||||
|
||||
|
||||
def get_clip_token_for_string(tokenizer, string):
|
||||
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"]
|
||||
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
|
||||
return tokens[0, 1]
|
||||
|
||||
|
||||
def get_bert_token_for_string(tokenizer, string):
|
||||
token = tokenizer(string)
|
||||
assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
|
||||
token = token[0, 1]
|
||||
return token
|
||||
|
||||
|
||||
def get_clip_vision_emb(encoder, processor, img):
|
||||
_img = img.repeat(1, 3, 1, 1)*255
|
||||
inputs = processor(images=_img, return_tensors="pt")
|
||||
inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
|
||||
outputs = encoder(**inputs)
|
||||
emb = outputs.image_embeds
|
||||
return emb
|
||||
|
||||
|
||||
def get_recog_emb(encoder, img_list):
|
||||
_img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
|
||||
encoder.predictor.eval()
|
||||
_, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
|
||||
return preds_neck
|
||||
|
||||
|
||||
def pad_H(x):
|
||||
_, _, H, W = x.shape
|
||||
p_top = (W - H) // 2
|
||||
p_bot = W - H - p_top
|
||||
return F.pad(x, (0, 0, p_top, p_bot))
|
||||
|
||||
|
||||
class EncodeNet(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(EncodeNet, self).__init__()
|
||||
chan = 16
|
||||
n_layer = 4 # downsample
|
||||
|
||||
self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
|
||||
self.conv_list = nn.ModuleList([])
|
||||
_c = chan
|
||||
for i in range(n_layer):
|
||||
self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
|
||||
_c *= 2
|
||||
self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.act(self.conv1(x))
|
||||
for layer in self.conv_list:
|
||||
x = self.act(layer(x))
|
||||
x = self.act(self.conv2(x))
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return x
|
||||
|
||||
|
||||
class EmbeddingManager(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedder,
|
||||
valid=True,
|
||||
glyph_channels=20,
|
||||
position_channels=1,
|
||||
placeholder_string='*',
|
||||
add_pos=False,
|
||||
emb_type='ocr',
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
|
||||
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
|
||||
token_dim = 768
|
||||
if hasattr(embedder, 'vit'):
|
||||
assert emb_type == 'vit'
|
||||
self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
|
||||
self.get_recog_emb = None
|
||||
else: # using LDM's BERT encoder
|
||||
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
|
||||
token_dim = 1280
|
||||
self.token_dim = token_dim
|
||||
self.emb_type = emb_type
|
||||
|
||||
self.add_pos = add_pos
|
||||
if add_pos:
|
||||
self.position_encoder = EncodeNet(position_channels, token_dim)
|
||||
if emb_type == 'ocr':
|
||||
self.proj = linear(40*64, token_dim)
|
||||
if emb_type == 'conv':
|
||||
self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
|
||||
|
||||
self.placeholder_token = get_token_for_string(placeholder_string)
|
||||
|
||||
def encode_text(self, text_info):
|
||||
if self.get_recog_emb is None and self.emb_type == 'ocr':
|
||||
self.get_recog_emb = partial(get_recog_emb, self.recog)
|
||||
|
||||
gline_list = []
|
||||
pos_list = []
|
||||
for i in range(len(text_info['n_lines'])): # sample index in a batch
|
||||
n_lines = text_info['n_lines'][i]
|
||||
for j in range(n_lines): # line
|
||||
gline_list += [text_info['gly_line'][j][i:i+1]]
|
||||
if self.add_pos:
|
||||
pos_list += [text_info['positions'][j][i:i+1]]
|
||||
|
||||
if len(gline_list) > 0:
|
||||
if self.emb_type == 'ocr':
|
||||
recog_emb = self.get_recog_emb(gline_list)
|
||||
enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
|
||||
elif self.emb_type == 'vit':
|
||||
enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
|
||||
elif self.emb_type == 'conv':
|
||||
enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
|
||||
if self.add_pos:
|
||||
enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
|
||||
enc_glyph = enc_glyph+enc_pos
|
||||
|
||||
self.text_embs_all = []
|
||||
n_idx = 0
|
||||
for i in range(len(text_info['n_lines'])): # sample index in a batch
|
||||
n_lines = text_info['n_lines'][i]
|
||||
text_embs = []
|
||||
for j in range(n_lines): # line
|
||||
text_embs += [enc_glyph[n_idx:n_idx+1]]
|
||||
n_idx += 1
|
||||
self.text_embs_all += [text_embs]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokenized_text,
|
||||
embedded_text,
|
||||
):
|
||||
b, device = tokenized_text.shape[0], tokenized_text.device
|
||||
for i in range(b):
|
||||
idx = tokenized_text[i] == self.placeholder_token.to(device)
|
||||
if sum(idx) > 0:
|
||||
if i >= len(self.text_embs_all):
|
||||
print('truncation for log images...')
|
||||
break
|
||||
text_emb = torch.cat(self.text_embs_all[i], dim=0)
|
||||
if sum(idx) != len(text_emb):
|
||||
print('truncation for long caption...')
|
||||
embedded_text[i][idx] = text_emb[:sum(idx)]
|
||||
return embedded_text
|
||||
|
||||
def embedding_parameters(self):
|
||||
return self.parameters()
|
@ -1,111 +0,0 @@
|
||||
import torch
|
||||
import einops
|
||||
|
||||
import iopaint.model.anytext.ldm.modules.encoders.modules
|
||||
import iopaint.model.anytext.ldm.modules.attention
|
||||
|
||||
from transformers import logging
|
||||
from iopaint.model.anytext.ldm.modules.attention import default
|
||||
|
||||
|
||||
def disable_verbosity():
|
||||
logging.set_verbosity_error()
|
||||
print('logging improved.')
|
||||
return
|
||||
|
||||
|
||||
def enable_sliced_attention():
|
||||
iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
|
||||
print('Enabled sliced_attention.')
|
||||
return
|
||||
|
||||
|
||||
def hack_everything(clip_skip=0):
|
||||
disable_verbosity()
|
||||
iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
|
||||
iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
|
||||
print('Enabled clip hacks.')
|
||||
return
|
||||
|
||||
|
||||
# Written by Lvmin
|
||||
def _hacked_clip_forward(self, text):
|
||||
PAD = self.tokenizer.pad_token_id
|
||||
EOS = self.tokenizer.eos_token_id
|
||||
BOS = self.tokenizer.bos_token_id
|
||||
|
||||
def tokenize(t):
|
||||
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def transformer_encode(t):
|
||||
if self.clip_skip > 1:
|
||||
rt = self.transformer(input_ids=t, output_hidden_states=True)
|
||||
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
|
||||
else:
|
||||
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
|
||||
|
||||
def split(x):
|
||||
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
|
||||
|
||||
def pad(x, p, i):
|
||||
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
|
||||
|
||||
raw_tokens_list = tokenize(text)
|
||||
tokens_list = []
|
||||
|
||||
for raw_tokens in raw_tokens_list:
|
||||
raw_tokens_123 = split(raw_tokens)
|
||||
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
|
||||
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
|
||||
tokens_list.append(raw_tokens_123)
|
||||
|
||||
tokens_list = torch.IntTensor(tokens_list).to(self.device)
|
||||
|
||||
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
|
||||
y = transformer_encode(feed)
|
||||
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
|
||||
|
||||
return z
|
||||
|
||||
|
||||
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
|
||||
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
limit = k.shape[0]
|
||||
att_step = 1
|
||||
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
|
||||
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
|
||||
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
|
||||
|
||||
q_chunks.reverse()
|
||||
k_chunks.reverse()
|
||||
v_chunks.reverse()
|
||||
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
del k, q, v
|
||||
for i in range(0, limit, att_step):
|
||||
q_buffer = q_chunks.pop()
|
||||
k_buffer = k_chunks.pop()
|
||||
v_buffer = v_chunks.pop()
|
||||
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
||||
|
||||
del k_buffer, q_buffer
|
||||
# attention, what we cannot get enough of, by chunks
|
||||
|
||||
sim_buffer = sim_buffer.softmax(dim=-1)
|
||||
|
||||
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
||||
del v_buffer
|
||||
sim[i:i + att_step, :, :] = sim_buffer
|
||||
|
||||
del sim_buffer
|
||||
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
@ -1,40 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from iopaint.model.anytext.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def get_state_dict(d):
|
||||
return d.get("state_dict", d)
|
||||
|
||||
|
||||
def load_state_dict(ckpt_path, location="cpu"):
|
||||
_, extension = os.path.splitext(ckpt_path)
|
||||
if extension.lower() == ".safetensors":
|
||||
import safetensors.torch
|
||||
|
||||
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
||||
else:
|
||||
state_dict = get_state_dict(
|
||||
torch.load(ckpt_path, map_location=torch.device(location))
|
||||
)
|
||||
state_dict = get_state_dict(state_dict)
|
||||
print(f"Loaded state_dict from [{ckpt_path}]")
|
||||
return state_dict
|
||||
|
||||
|
||||
def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
|
||||
config = OmegaConf.load(config_path)
|
||||
# if cond_stage_path:
|
||||
# config.model.params.cond_stage_config.params.version = (
|
||||
# cond_stage_path # use pre-downloaded ckpts, in case blocked
|
||||
# )
|
||||
config.model.params.cond_stage_config.params.device = str(device)
|
||||
if use_fp16:
|
||||
config.model.params.use_fp16 = True
|
||||
config.model.params.control_stage_config.params.use_fp16 = True
|
||||
config.model.params.unet_config.params.use_fp16 = True
|
||||
model = instantiate_from_config(config.model).cpu()
|
||||
print(f"Loaded model config from [{config_path}]")
|
||||
return model
|
@ -1,300 +0,0 @@
|
||||
"""
|
||||
Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
import traceback
|
||||
from easydict import EasyDict as edict
|
||||
import time
|
||||
from iopaint.model.anytext.ocr_recog.RecModel import RecModel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def min_bounding_rect(img):
|
||||
ret, thresh = cv2.threshold(img, 127, 255, 0)
|
||||
contours, hierarchy = cv2.findContours(
|
||||
thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
if len(contours) == 0:
|
||||
print("Bad contours, using fake bbox...")
|
||||
return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
|
||||
max_contour = max(contours, key=cv2.contourArea)
|
||||
rect = cv2.minAreaRect(max_contour)
|
||||
box = cv2.boxPoints(rect)
|
||||
box = np.int0(box)
|
||||
# sort
|
||||
x_sorted = sorted(box, key=lambda x: x[0])
|
||||
left = x_sorted[:2]
|
||||
right = x_sorted[2:]
|
||||
left = sorted(left, key=lambda x: x[1])
|
||||
(tl, bl) = left
|
||||
right = sorted(right, key=lambda x: x[1])
|
||||
(tr, br) = right
|
||||
if tl[1] > bl[1]:
|
||||
(tl, bl) = (bl, tl)
|
||||
if tr[1] > br[1]:
|
||||
(tr, br) = (br, tr)
|
||||
return np.array([tl, tr, br, bl])
|
||||
|
||||
|
||||
def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
|
||||
model_file_path = model_dir
|
||||
if model_file_path is not None and not os.path.exists(model_file_path):
|
||||
raise ValueError("not find model file path {}".format(model_file_path))
|
||||
|
||||
if is_onnx:
|
||||
import onnxruntime as ort
|
||||
|
||||
sess = ort.InferenceSession(
|
||||
model_file_path, providers=["CPUExecutionProvider"]
|
||||
) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
|
||||
return sess
|
||||
else:
|
||||
if model_lang == "ch":
|
||||
n_class = 6625
|
||||
elif model_lang == "en":
|
||||
n_class = 97
|
||||
else:
|
||||
raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
|
||||
rec_config = edict(
|
||||
in_channels=3,
|
||||
backbone=edict(
|
||||
type="MobileNetV1Enhance",
|
||||
scale=0.5,
|
||||
last_conv_stride=[1, 2],
|
||||
last_pool_type="avg",
|
||||
),
|
||||
neck=edict(
|
||||
type="SequenceEncoder",
|
||||
encoder_type="svtr",
|
||||
dims=64,
|
||||
depth=2,
|
||||
hidden_dims=120,
|
||||
use_guide=True,
|
||||
),
|
||||
head=edict(
|
||||
type="CTCHead",
|
||||
fc_decay=0.00001,
|
||||
out_channels=n_class,
|
||||
return_feats=True,
|
||||
),
|
||||
)
|
||||
|
||||
rec_model = RecModel(rec_config)
|
||||
if model_file_path is not None:
|
||||
rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
|
||||
rec_model.eval()
|
||||
return rec_model.eval()
|
||||
|
||||
|
||||
def _check_image_file(path):
|
||||
img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
|
||||
return any([path.lower().endswith(e) for e in img_end])
|
||||
|
||||
|
||||
def get_image_file_list(img_file):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
if os.path.isfile(img_file) and _check_image_file(img_file):
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
file_path = os.path.join(img_file, single_file)
|
||||
if os.path.isfile(file_path) and _check_image_file(file_path):
|
||||
imgs_lists.append(file_path)
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
imgs_lists = sorted(imgs_lists)
|
||||
return imgs_lists
|
||||
|
||||
|
||||
class TextRecognizer(object):
|
||||
def __init__(self, args, predictor):
|
||||
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
||||
self.rec_batch_num = args.rec_batch_num
|
||||
self.predictor = predictor
|
||||
self.chars = self.get_char_dict(args.rec_char_dict_path)
|
||||
self.char2id = {x: i for i, x in enumerate(self.chars)}
|
||||
self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
|
||||
self.use_fp16 = args.use_fp16
|
||||
|
||||
# img: CHW
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
assert imgC == img.shape[0]
|
||||
imgW = int((imgH * max_wh_ratio))
|
||||
|
||||
h, w = img.shape[1:]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = torch.nn.functional.interpolate(
|
||||
img.unsqueeze(0),
|
||||
size=(imgH, resized_w),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
resized_image /= 255.0
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
|
||||
padding_im[:, :, 0:resized_w] = resized_image[0]
|
||||
return padding_im
|
||||
|
||||
# img_list: list of tensors with shape chw 0-255
|
||||
def pred_imglist(self, img_list, show_debug=False, is_ori=False):
|
||||
img_num = len(img_list)
|
||||
assert img_num > 0
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[2] / float(img.shape[1]))
|
||||
# Sorting can speed up the recognition process
|
||||
indices = torch.from_numpy(np.argsort(np.array(width_list)))
|
||||
batch_num = self.rec_batch_num
|
||||
preds_all = [None] * img_num
|
||||
preds_neck_all = [None] * img_num
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
|
||||
imgC, imgH, imgW = self.rec_image_shape[:3]
|
||||
max_wh_ratio = imgW / imgH
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[1:]
|
||||
if h > w * 1.2:
|
||||
img = img_list[indices[ino]]
|
||||
img = torch.transpose(img, 1, 2).flip(dims=[1])
|
||||
img_list[indices[ino]] = img
|
||||
h, w = img.shape[1:]
|
||||
# wh_ratio = w * 1.0 / h
|
||||
# max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
|
||||
if self.use_fp16:
|
||||
norm_img = norm_img.half()
|
||||
norm_img = norm_img.unsqueeze(0)
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = torch.cat(norm_img_batch, dim=0)
|
||||
if show_debug:
|
||||
for i in range(len(norm_img_batch)):
|
||||
_img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
|
||||
_img = (_img + 0.5) * 255
|
||||
_img = _img[:, :, ::-1]
|
||||
file_name = f"{indices[beg_img_no + i]}"
|
||||
file_name = file_name + "_ori" if is_ori else file_name
|
||||
cv2.imwrite(file_name + ".jpg", _img)
|
||||
if self.is_onnx:
|
||||
input_dict = {}
|
||||
input_dict[self.predictor.get_inputs()[0].name] = (
|
||||
norm_img_batch.detach().cpu().numpy()
|
||||
)
|
||||
outputs = self.predictor.run(None, input_dict)
|
||||
preds = {}
|
||||
preds["ctc"] = torch.from_numpy(outputs[0])
|
||||
preds["ctc_neck"] = [torch.zeros(1)] * img_num
|
||||
else:
|
||||
preds = self.predictor(norm_img_batch)
|
||||
for rno in range(preds["ctc"].shape[0]):
|
||||
preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
|
||||
preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
|
||||
|
||||
return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
|
||||
|
||||
def get_char_dict(self, character_dict_path):
|
||||
character_str = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
||||
character_str.append(line)
|
||||
dict_character = list(character_str)
|
||||
dict_character = ["sos"] + dict_character + [" "] # eos is space
|
||||
return dict_character
|
||||
|
||||
def get_text(self, order):
|
||||
char_list = [self.chars[text_id] for text_id in order]
|
||||
return "".join(char_list)
|
||||
|
||||
def decode(self, mat):
|
||||
text_index = mat.detach().cpu().numpy().argmax(axis=1)
|
||||
ignored_tokens = [0]
|
||||
selection = np.ones(len(text_index), dtype=bool)
|
||||
selection[1:] = text_index[1:] != text_index[:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index != ignored_token
|
||||
return text_index[selection], np.where(selection)[0]
|
||||
|
||||
def get_ctcloss(self, preds, gt_text, weight):
|
||||
if not isinstance(weight, torch.Tensor):
|
||||
weight = torch.tensor(weight).to(preds.device)
|
||||
ctc_loss = torch.nn.CTCLoss(reduction="none")
|
||||
log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
|
||||
targets = []
|
||||
target_lengths = []
|
||||
for t in gt_text:
|
||||
targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
|
||||
target_lengths += [len(t)]
|
||||
targets = torch.tensor(targets).to(preds.device)
|
||||
target_lengths = torch.tensor(target_lengths).to(preds.device)
|
||||
input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(
|
||||
preds.device
|
||||
)
|
||||
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
|
||||
loss = loss / input_lengths * weight
|
||||
return loss
|
||||
|
||||
|
||||
def main():
|
||||
rec_model_dir = "./ocr_weights/ppv3_rec.pth"
|
||||
predictor = create_predictor(rec_model_dir)
|
||||
args = edict()
|
||||
args.rec_image_shape = "3, 48, 320"
|
||||
args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
|
||||
args.rec_batch_num = 6
|
||||
text_recognizer = TextRecognizer(args, predictor)
|
||||
image_dir = "./test_imgs_cn"
|
||||
gt_text = ["韩国小馆"] * 14
|
||||
|
||||
image_file_list = get_image_file_list(image_dir)
|
||||
valid_image_file_list = []
|
||||
img_list = []
|
||||
|
||||
for image_file in image_file_list:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
print("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
valid_image_file_list.append(image_file)
|
||||
img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
|
||||
try:
|
||||
tic = time.time()
|
||||
times = []
|
||||
for i in range(10):
|
||||
preds, _ = text_recognizer.pred_imglist(img_list) # get text
|
||||
preds_all = preds.softmax(dim=2)
|
||||
times += [(time.time() - tic) * 1000.0]
|
||||
tic = time.time()
|
||||
print(times)
|
||||
print(np.mean(times[1:]) / len(preds_all))
|
||||
weight = np.ones(len(gt_text))
|
||||
loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
|
||||
for i in range(len(valid_image_file_list)):
|
||||
pred = preds_all[i]
|
||||
order, idx = text_recognizer.decode(pred)
|
||||
text = text_recognizer.get_text(order)
|
||||
print(
|
||||
f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
|
||||
)
|
||||
except Exception as E:
|
||||
print(traceback.format_exc(), E)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,218 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from iopaint.model.anytext.ldm.util import instantiate_from_config
|
||||
from iopaint.model.anytext.ldm.modules.ema import LitEma
|
||||
|
||||
|
||||
class AutoencoderKL(torch.nn.Module):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
self.use_ema = ema_decay is not None
|
||||
if self.use_ema:
|
||||
self.ema_decay = ema_decay
|
||||
assert 0. < ema_decay < 1.
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
|
||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, postfix=""):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val"+postfix)
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val"+postfix)
|
||||
|
||||
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
|
||||
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
|
||||
if self.learn_logvar:
|
||||
print(f"{self.__class__.__name__}: Learning logvar")
|
||||
ae_params_list.append(self.loss.logvar)
|
||||
opt_ae = torch.optim.Adam(ae_params_list,
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
log["inputs"] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
@ -1,354 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
# cbs = len(ctmp[0])
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img], "index": [10000]}
|
||||
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['index'].append(index)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
unconditional_conditioning[k][i],
|
||||
c[k][i]]) for i in range(len(c[k]))]
|
||||
elif isinstance(c[k], dict):
|
||||
c_in[k] = dict()
|
||||
for key in c[k]:
|
||||
if isinstance(c[k][key], list):
|
||||
if not isinstance(c[k][key][0], torch.Tensor):
|
||||
continue
|
||||
c_in[k][key] = [torch.cat([
|
||||
unconditional_conditioning[k][key][i],
|
||||
c[k][key][i]]) for i in range(len(c[k][key]))]
|
||||
else:
|
||||
c_in[k][key] = torch.cat([
|
||||
unconditional_conditioning[k][key],
|
||||
c[k][key]])
|
||||
|
||||
else:
|
||||
c_in[k] = torch.cat([
|
||||
unconditional_conditioning[k],
|
||||
c[k]])
|
||||
elif isinstance(c, list):
|
||||
c_in = list()
|
||||
assert isinstance(unconditional_conditioning, list)
|
||||
for i in range(len(c)):
|
||||
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
||||
|
||||
if self.model.parameterization == "v":
|
||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps", 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
||||
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
||||
if unconditional_guidance_scale == 1.:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c))), 2)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
weighted_noise_pred = alphas_next[i].sqrt() * (
|
||||
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if return_intermediates and i % (
|
||||
num_steps // return_intermediates) == 0 and i < num_steps - 1:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback: callback(i)
|
||||
|
||||
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({'intermediates': intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False, callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback: callback(i)
|
||||
return x_dec
|
File diff suppressed because it is too large
Load Diff
@ -1 +0,0 @@
|
||||
from .sampler import DPMSolverSampler
|
File diff suppressed because it is too large
Load Diff
@ -1,87 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
import torch
|
||||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
}
|
||||
|
||||
|
||||
class DPMSolverSampler(object):
|
||||
def __init__(self, model, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
|
||||
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
img = torch.randn(size, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
ns,
|
||||
model_type=MODEL_TYPES[self.model.parameterization],
|
||||
guidance_type="classifier-free",
|
||||
condition=conditioning,
|
||||
unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
@ -1,244 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from iopaint.model.anytext.ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
@ -1,22 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
@ -1,360 +0,0 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
# CrossAttn precision handling
|
||||
import os
|
||||
|
||||
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b (h w) c")
|
||||
k = rearrange(k, "b c h w -> b c (h w)")
|
||||
w_ = torch.einsum("bij,bjk->bik", q, k)
|
||||
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, "b c h w -> b c (h w)")
|
||||
w_ = rearrange(w_, "b i j -> b j i")
|
||||
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
||||
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION == "fp32":
|
||||
with torch.autocast(enabled=False, device_type="cuda"):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
else:
|
||||
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
del q, k
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, "b ... -> b (...)")
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, "b j -> (b h) () j", h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum("b i j, b j d -> b i d", sim, v)
|
||||
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SDPACrossAttention(CrossAttention):
|
||||
def forward(self, x, context=None, mask=None):
|
||||
batch_size, sequence_length, inner_dim = x.shape
|
||||
|
||||
if mask is not None:
|
||||
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
|
||||
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
|
||||
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
|
||||
head_dim = inner_dim // h
|
||||
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
if _ATTN_PRECISION == "fp32":
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, h * head_dim
|
||||
)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
attn_cls = SDPACrossAttention
|
||||
else:
|
||||
attn_cls = CrossAttention
|
||||
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None,
|
||||
) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = (
|
||||
self.attn1(
|
||||
self.norm1(x), context=context if self.disable_self_attn else None
|
||||
)
|
||||
+ x
|
||||
)
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
NEW: use_linear for more efficiency instead of the 1x1 convs
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
disable_self_attn=False,
|
||||
use_linear=False,
|
||||
use_checkpoint=True,
|
||||
):
|
||||
super().__init__()
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim]
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
else:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim[d],
|
||||
disable_self_attn=disable_self_attn,
|
||||
checkpoint=use_checkpoint,
|
||||
)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
else:
|
||||
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||
self.use_linear = use_linear
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
if not isinstance(context, list):
|
||||
context = [context]
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
x = block(x, context=context[i])
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
@ -1,973 +0,0 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class AttnBlock2_0(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
# output: [1, 512, 64, 64]
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
|
||||
# q = q.reshape(b, c, h * w).transpose()
|
||||
# q = q.permute(0, 2, 1) # b,hw,c
|
||||
# k = k.reshape(b, c, h * w) # b,c,hw
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
# (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
h_ = self.proj_out(hidden_states)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
assert attn_type in [
|
||||
"vanilla",
|
||||
"vanilla-xformers",
|
||||
"memory-efficient-cross-attn",
|
||||
"linear",
|
||||
"none",
|
||||
], f"attn_type {attn_type} unknown"
|
||||
assert attn_kwargs is None
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
# print(f"Using torch.nn.functional.scaled_dot_product_attention")
|
||||
return AttnBlock2_0(in_channels)
|
||||
return AttnBlock(in_channels)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
use_timestep=True,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.ch, self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
||||
]
|
||||
)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
skip_in = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution
|
||||
if context is not None:
|
||||
# assume aligned context, cat along channel axis
|
||||
x = torch.cat((x, context), dim=1)
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb
|
||||
)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class SimpleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.model = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0,
|
||||
),
|
||||
ResnetBlock(
|
||||
in_channels=2 * in_channels,
|
||||
out_channels=4 * in_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0,
|
||||
),
|
||||
ResnetBlock(
|
||||
in_channels=4 * in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0,
|
||||
),
|
||||
nn.Conv2d(2 * in_channels, in_channels, 1),
|
||||
Upsample(in_channels, with_conv=True),
|
||||
]
|
||||
)
|
||||
# end
|
||||
self.norm_out = Normalize(in_channels)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.model):
|
||||
if i in [1, 2, 3]:
|
||||
x = layer(x, None)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
ch,
|
||||
num_res_blocks,
|
||||
resolution,
|
||||
ch_mult=(2, 2),
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
# upsampling
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = in_channels
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.res_blocks = nn.ModuleList()
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
res_block = []
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
res_block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
self.res_blocks.append(nn.ModuleList(res_block))
|
||||
if i_level != self.num_resolutions - 1:
|
||||
self.upsample_blocks.append(Upsample(block_in, True))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# upsampling
|
||||
h = x
|
||||
for k, i_level in enumerate(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.res_blocks[i_level][i_block](h, None)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class LatentRescaler(nn.Module):
|
||||
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
||||
super().__init__()
|
||||
# residual block, interpolate, residual block
|
||||
self.factor = factor
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, mid_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.res_block1 = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.attn = AttnBlock(mid_channels)
|
||||
self.res_block2 = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
for block in self.res_block1:
|
||||
x = block(x, None)
|
||||
x = torch.nn.functional.interpolate(
|
||||
x,
|
||||
size=(
|
||||
int(round(x.shape[2] * self.factor)),
|
||||
int(round(x.shape[3] * self.factor)),
|
||||
),
|
||||
)
|
||||
x = self.attn(x)
|
||||
for block in self.res_block2:
|
||||
x = block(x, None)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
ch,
|
||||
resolution,
|
||||
out_ch,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
rescale_factor=1.0,
|
||||
rescale_module_depth=1,
|
||||
):
|
||||
super().__init__()
|
||||
intermediate_chn = ch * ch_mult[-1]
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
z_channels=intermediate_chn,
|
||||
double_z=False,
|
||||
resolution=resolution,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
out_ch=None,
|
||||
)
|
||||
self.rescaler = LatentRescaler(
|
||||
factor=rescale_factor,
|
||||
in_channels=intermediate_chn,
|
||||
mid_channels=intermediate_chn,
|
||||
out_channels=out_ch,
|
||||
depth=rescale_module_depth,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.rescaler(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels,
|
||||
out_ch,
|
||||
resolution,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
rescale_factor=1.0,
|
||||
rescale_module_depth=1,
|
||||
):
|
||||
super().__init__()
|
||||
tmp_chn = z_channels * ch_mult[-1]
|
||||
self.decoder = Decoder(
|
||||
out_ch=out_ch,
|
||||
z_channels=tmp_chn,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
in_channels=None,
|
||||
num_res_blocks=num_res_blocks,
|
||||
ch_mult=ch_mult,
|
||||
resolution=resolution,
|
||||
ch=ch,
|
||||
)
|
||||
self.rescaler = LatentRescaler(
|
||||
factor=rescale_factor,
|
||||
in_channels=z_channels,
|
||||
mid_channels=tmp_chn,
|
||||
out_channels=tmp_chn,
|
||||
depth=rescale_module_depth,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
||||
super().__init__()
|
||||
assert out_size >= in_size
|
||||
num_blocks = int(np.log2(out_size // in_size)) + 1
|
||||
factor_up = 1.0 + (out_size % in_size)
|
||||
print(
|
||||
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
|
||||
)
|
||||
self.rescaler = LatentRescaler(
|
||||
factor=factor_up,
|
||||
in_channels=in_channels,
|
||||
mid_channels=2 * in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
out_ch=out_channels,
|
||||
resolution=out_size,
|
||||
z_channels=in_channels,
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=[],
|
||||
in_channels=None,
|
||||
ch=in_channels,
|
||||
ch_mult=[ch_mult for _ in range(num_blocks)],
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Resize(nn.Module):
|
||||
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
||||
super().__init__()
|
||||
self.with_conv = learned
|
||||
self.mode = mode
|
||||
if self.with_conv:
|
||||
print(
|
||||
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
|
||||
)
|
||||
raise NotImplementedError()
|
||||
assert in_channels is not None
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=4, stride=2, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x, scale_factor=1.0):
|
||||
if scale_factor == 1.0:
|
||||
return x
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(
|
||||
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
|
||||
)
|
||||
return x
|
@ -1,786 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
|
||||
from iopaint.model.anytext.ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
def convert_module_to_f16(x):
|
||||
pass
|
||||
|
||||
def convert_module_to_f32(x):
|
||||
pass
|
||||
|
||||
|
||||
## go
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads_channels: int,
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
return x[:, :, 0]
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
)
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
'Learned 2x upsampling without padding'
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
||||
|
||||
def forward(self,x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
use_scale_shift_norm=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
return checkpoint(
|
||||
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
||||
)
|
||||
|
||||
|
||||
def _forward(self, x, emb):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_checkpoint=False,
|
||||
use_new_attention_order=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
if use_new_attention_order:
|
||||
# split qkv before split heads
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
else:
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
#return pt_checkpoint(self._forward, x) # pytorch
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an
|
||||
attention operation.
|
||||
Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts",
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
self.use_fp16 = use_fp16
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
self.output_blocks.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param context: conditioning plugged in via crossattn
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
@ -1,81 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
||||
from iopaint.model.anytext.ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
# for concatenating a downsampled image to the latent representation
|
||||
def __init__(self, noise_schedule_config=None):
|
||||
super(AbstractLowScaleModel, self).__init__()
|
||||
if noise_schedule_config is not None:
|
||||
self.register_schedule(**noise_schedule_config)
|
||||
|
||||
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
|
||||
def forward(self, x):
|
||||
return x, None
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class SimpleImageConcat(AbstractLowScaleModel):
|
||||
# no noise level conditioning
|
||||
def __init__(self):
|
||||
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
|
||||
self.max_noise_level = 0
|
||||
|
||||
def forward(self, x):
|
||||
# fix to constant noise level
|
||||
return x, torch.zeros(x.shape[0], device=x.device).long()
|
||||
|
||||
|
||||
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
||||
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
|
||||
super().__init__(noise_schedule_config=noise_schedule_config)
|
||||
self.max_noise_level = max_noise_level
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
z = self.q_sample(x, noise_level)
|
||||
return z, noise_level
|
||||
|
||||
|
||||
|
@ -1,271 +0,0 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from iopaint.model.anytext.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
)
|
||||
|
||||
elif schedule == "cosine":
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
||||
)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == "sqrt_linear":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == "sqrt":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
print(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
return sigmas.to(torch.float32), alphas.to(torch.float32), alphas_prev.astype(np.float32)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
|
||||
"dtype": torch.get_autocast_gpu_dtype(),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled()}
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with torch.enable_grad(), \
|
||||
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
# return super().forward(x.float()).type(x.dtype)
|
||||
return super().forward(x).type(x.dtype)
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
@ -1,92 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
@ -1,80 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
|
||||
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
||||
super().__init__()
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
|
||||
self.m_name2s_name = {}
|
||||
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
|
||||
else torch.tensor(-1, dtype=torch.int))
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
# remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace('.', '')
|
||||
self.m_name2s_name.update({name: s_name})
|
||||
self.register_buffer(s_name, p.clone().detach().data)
|
||||
|
||||
self.collected_params = []
|
||||
|
||||
def reset_num_updates(self):
|
||||
del self.num_updates
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
|
||||
|
||||
def forward(self, model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
with torch.no_grad():
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
@ -1,411 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from transformers import (
|
||||
T5Tokenizer,
|
||||
T5EncoderModel,
|
||||
CLIPTokenizer,
|
||||
CLIPTextModel,
|
||||
AutoProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from iopaint.model.anytext.ldm.util import count_params
|
||||
|
||||
|
||||
def _expand_mask(mask, dtype, tgt_len=None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _build_causal_attention_mask(bsz, seq_len, dtype):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
|
||||
mask.fill_(torch.tensor(torch.finfo(dtype).min))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
self.n_classes = n_classes
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def forward(self, batch, key=None, disable_dropout=False):
|
||||
if key is None:
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0.0 and not disable_dropout:
|
||||
mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
def get_unconditional_conditioning(self, bs, device="cuda"):
|
||||
uc_class = (
|
||||
self.n_classes - 1
|
||||
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc = torch.ones((bs,), device=device) * uc_class
|
||||
uc = {self.key: uc}
|
||||
return uc
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
|
||||
def __init__(
|
||||
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
|
||||
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version="openai/clip-vit-large-patch14",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer="last",
|
||||
layer_idx=None,
|
||||
): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
self.layer_idx = layer_idx
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert 0 <= abs(layer_idx) <= 12
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(
|
||||
input_ids=tokens, output_hidden_states=self.layer == "hidden"
|
||||
)
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
z = outputs.pooler_output[:, None, :]
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
clip_version="openai/clip-vit-large-patch14",
|
||||
t5_version="google/t5-v1_1-xl",
|
||||
device="cuda",
|
||||
clip_max_length=77,
|
||||
t5_max_length=77,
|
||||
):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(
|
||||
clip_version, device, max_length=clip_max_length
|
||||
)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
print(
|
||||
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
|
||||
)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
def forward(self, text):
|
||||
clip_z = self.clip_encoder.encode(text)
|
||||
t5_z = self.t5_encoder.encode(text)
|
||||
return [clip_z, t5_z]
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version="openai/clip-vit-large-patch14",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
use_vision=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
if use_vision:
|
||||
self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
|
||||
self.processor = AutoProcessor.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def embedding_forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
seq_length = (
|
||||
input_ids.shape[-1]
|
||||
if input_ids is not None
|
||||
else inputs_embeds.shape[-2]
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
if embedding_manager is not None:
|
||||
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
return embeddings
|
||||
|
||||
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
|
||||
self.transformer.text_model.embeddings
|
||||
)
|
||||
|
||||
def encoder_forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask=None,
|
||||
causal_attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
return hidden_states
|
||||
|
||||
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
|
||||
self.transformer.text_model.encoder
|
||||
)
|
||||
|
||||
def text_encoder_forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify either input_ids")
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
bsz, seq_len = input_shape
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = _build_causal_attention_mask(
|
||||
bsz, seq_len, hidden_states.dtype
|
||||
).to(hidden_states.device)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
||||
last_hidden_state = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
return last_hidden_state
|
||||
|
||||
self.transformer.text_model.forward = text_encoder_forward.__get__(
|
||||
self.transformer.text_model
|
||||
)
|
||||
|
||||
def transformer_forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
return self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
|
||||
self.transformer.forward = transformer_forward.__get__(self.transformer)
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text, **kwargs):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
z = self.transformer(input_ids=tokens, **kwargs)
|
||||
return z
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
@ -1,197 +0,0 @@
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
from torch import optim
|
||||
import numpy as np
|
||||
|
||||
from inspect import isfunction
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
|
||||
nc = int(32 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x,torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config, **kwargs):
|
||||
if "target" not in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
class AdamWwithEMAandWings(optim.Optimizer):
|
||||
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
||||
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
||||
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
||||
ema_power=1., param_names=()):
|
||||
"""AdamW that saves EMA versions of the parameters."""
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||
if not 0.0 <= ema_decay <= 1.0:
|
||||
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
||||
ema_power=ema_power, param_names=param_names)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Args:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
state_sums = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
beta1, beta2 = group['betas']
|
||||
ema_decay = group['ema_decay']
|
||||
ema_power = group['ema_power']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError('AdamW does not support sparse gradients')
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of parameter values
|
||||
state['param_exp_avg'] = p.detach().float().clone()
|
||||
|
||||
exp_avgs.append(state['exp_avg'])
|
||||
exp_avg_sqs.append(state['exp_avg_sq'])
|
||||
ema_params_with_grad.append(state['param_exp_avg'])
|
||||
|
||||
if amsgrad:
|
||||
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
||||
|
||||
# update the steps for each param group update
|
||||
state['step'] += 1
|
||||
# record the step after step update
|
||||
state_steps.append(state['step'])
|
||||
|
||||
optim._functional.adamw(params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group['lr'],
|
||||
weight_decay=group['weight_decay'],
|
||||
eps=group['eps'],
|
||||
maximize=False)
|
||||
|
||||
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
||||
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
||||
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
||||
|
||||
return loss
|
@ -1,45 +0,0 @@
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from anytext_pipeline import AnyTextPipeline
|
||||
from utils import save_images
|
||||
|
||||
seed = 66273235
|
||||
# seed_everything(seed)
|
||||
|
||||
pipe = AnyTextPipeline(
|
||||
ckpt_path="/Users/cwq/code/github/IOPaint/iopaint/model/anytext/anytext_v1.1_fp16.ckpt",
|
||||
font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf",
|
||||
use_fp16=False,
|
||||
device="mps",
|
||||
)
|
||||
|
||||
img_save_folder = "SaveImages"
|
||||
rgb_image = cv2.imread(
|
||||
"/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg"
|
||||
)[..., ::-1]
|
||||
|
||||
masked_image = cv2.imread(
|
||||
"/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png"
|
||||
)[..., ::-1]
|
||||
|
||||
rgb_image = cv2.resize(rgb_image, (512, 512))
|
||||
masked_image = cv2.resize(masked_image, (512, 512))
|
||||
|
||||
# results: list of rgb ndarray
|
||||
results, rtn_code, rtn_warning = pipe(
|
||||
prompt='A cake with colorful characters that reads "EVERYDAY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
|
||||
negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
image=rgb_image,
|
||||
masked_image=masked_image,
|
||||
num_inference_steps=20,
|
||||
strength=1.0,
|
||||
guidance_scale=9.0,
|
||||
height=rgb_image.shape[0],
|
||||
width=rgb_image.shape[1],
|
||||
seed=seed,
|
||||
sort_priority="y",
|
||||
)
|
||||
if rtn_code >= 0:
|
||||
save_images(results, img_save_folder)
|
||||
print(f"Done, result images are saved in: {img_save_folder}")
|
@ -1,210 +0,0 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from .RecSVTR import Block
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self,x):
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
class Im2Im(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
class Im2Seq(nn.Module):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# assert H == 1
|
||||
x = x.reshape(B, C, H * W)
|
||||
x = x.permute((0, 2, 1))
|
||||
return x
|
||||
|
||||
class EncoderWithRNN(nn.Module):
|
||||
def __init__(self, in_channels,**kwargs):
|
||||
super(EncoderWithRNN, self).__init__()
|
||||
hidden_size = kwargs.get('hidden_size', 256)
|
||||
self.out_channels = hidden_size * 2
|
||||
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True)
|
||||
|
||||
def forward(self, x):
|
||||
self.lstm.flatten_parameters()
|
||||
x, _ = self.lstm(x)
|
||||
return x
|
||||
|
||||
class SequenceEncoder(nn.Module):
|
||||
def __init__(self, in_channels, encoder_type='rnn', **kwargs):
|
||||
super(SequenceEncoder, self).__init__()
|
||||
self.encoder_reshape = Im2Seq(in_channels)
|
||||
self.out_channels = self.encoder_reshape.out_channels
|
||||
self.encoder_type = encoder_type
|
||||
if encoder_type == 'reshape':
|
||||
self.only_reshape = True
|
||||
else:
|
||||
support_encoder_dict = {
|
||||
'reshape': Im2Seq,
|
||||
'rnn': EncoderWithRNN,
|
||||
'svtr': EncoderWithSVTR
|
||||
}
|
||||
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
|
||||
encoder_type, support_encoder_dict.keys())
|
||||
|
||||
self.encoder = support_encoder_dict[encoder_type](
|
||||
self.encoder_reshape.out_channels,**kwargs)
|
||||
self.out_channels = self.encoder.out_channels
|
||||
self.only_reshape = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.encoder_type != 'svtr':
|
||||
x = self.encoder_reshape(x)
|
||||
if not self.only_reshape:
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
else:
|
||||
x = self.encoder(x)
|
||||
x = self.encoder_reshape(x)
|
||||
return x
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias_attr=False,
|
||||
groups=1,
|
||||
act=nn.GELU):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = Swish()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class EncoderWithSVTR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
dims=64, # XS
|
||||
depth=2,
|
||||
hidden_dims=120,
|
||||
use_guide=False,
|
||||
num_heads=8,
|
||||
qkv_bias=True,
|
||||
mlp_ratio=2.0,
|
||||
drop_rate=0.1,
|
||||
attn_drop_rate=0.1,
|
||||
drop_path=0.,
|
||||
qk_scale=None):
|
||||
super(EncoderWithSVTR, self).__init__()
|
||||
self.depth = depth
|
||||
self.use_guide = use_guide
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels, in_channels // 8, padding=1, act='swish')
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels // 8, hidden_dims, kernel_size=1, act='swish')
|
||||
|
||||
self.svtr_block = nn.ModuleList([
|
||||
Block(
|
||||
dim=hidden_dims,
|
||||
num_heads=num_heads,
|
||||
mixer='Global',
|
||||
HW=None,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer='swish',
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=drop_path,
|
||||
norm_layer='nn.LayerNorm',
|
||||
epsilon=1e-05,
|
||||
prenorm=False) for i in range(depth)
|
||||
])
|
||||
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
|
||||
self.conv3 = ConvBNLayer(
|
||||
hidden_dims, in_channels, kernel_size=1, act='swish')
|
||||
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
|
||||
self.conv4 = ConvBNLayer(
|
||||
2 * in_channels, in_channels // 8, padding=1, act='swish')
|
||||
|
||||
self.conv1x1 = ConvBNLayer(
|
||||
in_channels // 8, dims, kernel_size=1, act='swish')
|
||||
self.out_channels = dims
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
# weight initialization
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# for use guide
|
||||
if self.use_guide:
|
||||
z = x.clone()
|
||||
z.stop_gradient = True
|
||||
else:
|
||||
z = x
|
||||
# for short cut
|
||||
h = z
|
||||
# reduce dim
|
||||
z = self.conv1(z)
|
||||
z = self.conv2(z)
|
||||
# SVTR global block
|
||||
B, C, H, W = z.shape
|
||||
z = z.flatten(2).permute(0, 2, 1)
|
||||
|
||||
for blk in self.svtr_block:
|
||||
z = blk(z)
|
||||
|
||||
z = self.norm(z)
|
||||
# last stage
|
||||
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
|
||||
z = self.conv3(z)
|
||||
z = torch.cat((h, z), dim=1)
|
||||
z = self.conv1x1(self.conv4(z))
|
||||
|
||||
return z
|
||||
|
||||
if __name__=="__main__":
|
||||
svtrRNN = EncoderWithSVTR(56)
|
||||
print(svtrRNN)
|
@ -1,48 +0,0 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CTCHead(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels=6625,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
return_feats=False,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
bias=True,)
|
||||
else:
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
bias=True,
|
||||
)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.return_feats = return_feats
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
x = self.fc1(x)
|
||||
predicts = self.fc2(x)
|
||||
|
||||
if self.return_feats:
|
||||
result = dict()
|
||||
result['ctc'] = predicts
|
||||
result['ctc_neck'] = x
|
||||
else:
|
||||
result = predicts
|
||||
|
||||
return result
|
@ -1,45 +0,0 @@
|
||||
from torch import nn
|
||||
from .RNN import SequenceEncoder, Im2Seq, Im2Im
|
||||
from .RecMv1_enhance import MobileNetV1Enhance
|
||||
|
||||
from .RecCTCHead import CTCHead
|
||||
|
||||
backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance}
|
||||
neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
|
||||
head_dict = {'CTCHead':CTCHead}
|
||||
|
||||
|
||||
class RecModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
assert 'in_channels' in config, 'in_channels must in model config'
|
||||
backbone_type = config.backbone.pop('type')
|
||||
assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
|
||||
self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
|
||||
|
||||
neck_type = config.neck.pop('type')
|
||||
assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
|
||||
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
|
||||
|
||||
head_type = config.head.pop('type')
|
||||
assert head_type in head_dict, f'head.type must in {head_dict}'
|
||||
self.head = head_dict[head_type](self.neck.out_channels, **config.head)
|
||||
|
||||
self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'
|
||||
|
||||
def load_3rd_state_dict(self, _3rd_name, _state):
|
||||
self.backbone.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.neck.load_3rd_state_dict(_3rd_name, _state)
|
||||
self.head.load_3rd_state_dict(_3rd_name, _state)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def encode(self, x):
|
||||
x = self.backbone(x)
|
||||
x = self.neck(x)
|
||||
x = self.head.ctc_encoder(x)
|
||||
return x
|
@ -1,232 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .common import Activation
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
filter_size,
|
||||
num_filters,
|
||||
stride,
|
||||
padding,
|
||||
channels=None,
|
||||
num_groups=1,
|
||||
act='hard_swish'):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.act = act
|
||||
self._conv = nn.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
bias=False)
|
||||
|
||||
self._batch_norm = nn.BatchNorm2d(
|
||||
num_filters,
|
||||
)
|
||||
if self.act is not None:
|
||||
self._act = Activation(act_type=act, inplace=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
if self.act is not None:
|
||||
y = self._act(y)
|
||||
return y
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Module):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters1,
|
||||
num_filters2,
|
||||
num_groups,
|
||||
stride,
|
||||
scale,
|
||||
dw_size=3,
|
||||
padding=1,
|
||||
use_se=False):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self._depthwise_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=int(num_filters1 * scale),
|
||||
filter_size=dw_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
num_groups=int(num_groups * scale))
|
||||
if use_se:
|
||||
self._se = SEModule(int(num_filters1 * scale))
|
||||
self._pointwise_conv = ConvBNLayer(
|
||||
num_channels=int(num_filters1 * scale),
|
||||
filter_size=1,
|
||||
num_filters=int(num_filters2 * scale),
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._depthwise_conv(inputs)
|
||||
if self.use_se:
|
||||
y = self._se(y)
|
||||
y = self._pointwise_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class MobileNetV1Enhance(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
scale=0.5,
|
||||
last_conv_stride=1,
|
||||
last_pool_type='max',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.block_list = []
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=in_channels,
|
||||
filter_size=3,
|
||||
channels=3,
|
||||
num_filters=int(32 * scale),
|
||||
stride=2,
|
||||
padding=1)
|
||||
|
||||
conv2_1 = DepthwiseSeparable(
|
||||
num_channels=int(32 * scale),
|
||||
num_filters1=32,
|
||||
num_filters2=64,
|
||||
num_groups=32,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv2_1)
|
||||
|
||||
conv2_2 = DepthwiseSeparable(
|
||||
num_channels=int(64 * scale),
|
||||
num_filters1=64,
|
||||
num_filters2=128,
|
||||
num_groups=64,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv2_2)
|
||||
|
||||
conv3_1 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=128,
|
||||
num_groups=128,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv3_1)
|
||||
|
||||
conv3_2 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=256,
|
||||
num_groups=128,
|
||||
stride=(2, 1),
|
||||
scale=scale)
|
||||
self.block_list.append(conv3_2)
|
||||
|
||||
conv4_1 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=256,
|
||||
num_groups=256,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv4_1)
|
||||
|
||||
conv4_2 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=512,
|
||||
num_groups=256,
|
||||
stride=(2, 1),
|
||||
scale=scale)
|
||||
self.block_list.append(conv4_2)
|
||||
|
||||
for _ in range(5):
|
||||
conv5 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=512,
|
||||
num_groups=512,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=False)
|
||||
self.block_list.append(conv5)
|
||||
|
||||
conv5_6 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=1024,
|
||||
num_groups=512,
|
||||
stride=(2, 1),
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=True)
|
||||
self.block_list.append(conv5_6)
|
||||
|
||||
conv6 = DepthwiseSeparable(
|
||||
num_channels=int(1024 * scale),
|
||||
num_filters1=1024,
|
||||
num_filters2=1024,
|
||||
num_groups=1024,
|
||||
stride=last_conv_stride,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
use_se=True,
|
||||
scale=scale)
|
||||
self.block_list.append(conv6)
|
||||
|
||||
self.block_list = nn.Sequential(*self.block_list)
|
||||
if last_pool_type == 'avg':
|
||||
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = int(1024 * scale)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1(inputs)
|
||||
y = self.block_list(y)
|
||||
y = self.pool(y)
|
||||
return y
|
||||
|
||||
def hardsigmoid(x):
|
||||
return F.relu6(x + 3., inplace=True) / 6.
|
||||
|
||||
class SEModule(nn.Module):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=channel,
|
||||
out_channels=channel // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=channel // reduction,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = hardsigmoid(outputs)
|
||||
x = torch.mul(inputs, outputs)
|
||||
|
||||
return x
|
@ -1,591 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from torch.nn.init import trunc_normal_, zeros_, ones_
|
||||
from torch.nn import functional
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0., training=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = torch.tensor(1 - drop_prob)
|
||||
shape = (x.size()[0], ) + (1, ) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
|
||||
random_tensor = torch.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __int__(self):
|
||||
super(Swish, self).__int__()
|
||||
|
||||
def forward(self,x):
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias_attr=False,
|
||||
groups=1,
|
||||
act=nn.GELU):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias=bias_attr)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
if isinstance(act_layer, str):
|
||||
self.act = Swish()
|
||||
else:
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvMixer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
HW=(8, 25),
|
||||
local_k=(3, 3), ):
|
||||
super().__init__()
|
||||
self.HW = HW
|
||||
self.dim = dim
|
||||
self.local_mixer = nn.Conv2d(
|
||||
dim,
|
||||
dim,
|
||||
local_k,
|
||||
1, (local_k[0] // 2, local_k[1] // 2),
|
||||
groups=num_heads,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.HW[0]
|
||||
w = self.HW[1]
|
||||
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
|
||||
x = self.local_mixer(x)
|
||||
x = x.flatten(2).transpose([0, 2, 1])
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
mixer='Global',
|
||||
HW=(8, 25),
|
||||
local_k=(7, 11),
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.HW = HW
|
||||
if HW is not None:
|
||||
H = HW[0]
|
||||
W = HW[1]
|
||||
self.N = H * W
|
||||
self.C = dim
|
||||
if mixer == 'Local' and HW is not None:
|
||||
hk = local_k[0]
|
||||
wk = local_k[1]
|
||||
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
|
||||
for h in range(0, H):
|
||||
for w in range(0, W):
|
||||
mask[h * W + w, h:h + hk, w:w + wk] = 0.
|
||||
mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
|
||||
2].flatten(1)
|
||||
mask_inf = torch.full([H * W, H * W],fill_value=float('-inf'))
|
||||
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
|
||||
self.mask = mask[None,None,:]
|
||||
# self.mask = mask.unsqueeze([0, 1])
|
||||
self.mixer = mixer
|
||||
|
||||
def forward(self, x):
|
||||
if self.HW is not None:
|
||||
N = self.N
|
||||
C = self.C
|
||||
else:
|
||||
_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4))
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
|
||||
attn = (q.matmul(k.permute((0, 1, 3, 2))))
|
||||
if self.mixer == 'Local':
|
||||
attn += self.mask
|
||||
attn = functional.softmax(attn, dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
mixer='Global',
|
||||
local_mixer=(7, 11),
|
||||
HW=(8, 25),
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer='nn.LayerNorm',
|
||||
epsilon=1e-6,
|
||||
prenorm=True):
|
||||
super().__init__()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm1 = norm_layer(dim)
|
||||
if mixer == 'Global' or mixer == 'Local':
|
||||
|
||||
self.mixer = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
mixer=mixer,
|
||||
HW=HW,
|
||||
local_k=local_mixer,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop)
|
||||
elif mixer == 'Conv':
|
||||
self.mixer = ConvMixer(
|
||||
dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
|
||||
else:
|
||||
raise TypeError("The mixer must be one of [Global, Local, Conv]")
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
|
||||
if isinstance(norm_layer, str):
|
||||
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
|
||||
else:
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp = Mlp(in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
self.prenorm = prenorm
|
||||
|
||||
def forward(self, x):
|
||||
if self.prenorm:
|
||||
x = self.norm1(x + self.drop_path(self.mixer(x)))
|
||||
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=(32, 100),
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
sub_num=2):
|
||||
super().__init__()
|
||||
num_patches = (img_size[1] // (2 ** sub_num)) * \
|
||||
(img_size[0] // (2 ** sub_num))
|
||||
self.img_size = img_size
|
||||
self.num_patches = num_patches
|
||||
self.embed_dim = embed_dim
|
||||
self.norm = None
|
||||
if sub_num == 2:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False))
|
||||
if sub_num == 3:
|
||||
self.proj = nn.Sequential(
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 4,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=False))
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class SubSample(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
types='Pool',
|
||||
stride=(2, 1),
|
||||
sub_norm='nn.LayerNorm',
|
||||
act=None):
|
||||
super().__init__()
|
||||
self.types = types
|
||||
if types == 'Pool':
|
||||
self.avgpool = nn.AvgPool2d(
|
||||
kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.maxpool = nn.MaxPool2d(
|
||||
kernel_size=(3, 5), stride=stride, padding=(1, 2))
|
||||
self.proj = nn.Linear(in_channels, out_channels)
|
||||
else:
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
# weight_attr=ParamAttr(initializer=KaimingNormal())
|
||||
)
|
||||
self.norm = eval(sub_norm)(out_channels)
|
||||
if act is not None:
|
||||
self.act = act()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.types == 'Pool':
|
||||
x1 = self.avgpool(x)
|
||||
x2 = self.maxpool(x)
|
||||
x = (x1 + x2) * 0.5
|
||||
out = self.proj(x.flatten(2).permute((0, 2, 1)))
|
||||
else:
|
||||
x = self.conv(x)
|
||||
out = x.flatten(2).permute((0, 2, 1))
|
||||
out = self.norm(out)
|
||||
if self.act is not None:
|
||||
out = self.act(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SVTRNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=[48, 100],
|
||||
in_channels=3,
|
||||
embed_dim=[64, 128, 256],
|
||||
depth=[3, 6, 3],
|
||||
num_heads=[2, 4, 8],
|
||||
mixer=['Local'] * 6 + ['Global'] *
|
||||
6, # Local atten, Global atten, Conv
|
||||
local_mixer=[[7, 11], [7, 11], [7, 11]],
|
||||
patch_merging='Conv', # Conv, Pool, None
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
last_drop=0.1,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer='nn.LayerNorm',
|
||||
sub_norm='nn.LayerNorm',
|
||||
epsilon=1e-6,
|
||||
out_channels=192,
|
||||
out_char_num=25,
|
||||
block_unit='Block',
|
||||
act='nn.GELU',
|
||||
last_stage=True,
|
||||
sub_num=2,
|
||||
prenorm=True,
|
||||
use_lenhead=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.embed_dim = embed_dim
|
||||
self.out_channels = out_channels
|
||||
self.prenorm = prenorm
|
||||
patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=embed_dim[0],
|
||||
sub_num=sub_num)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
|
||||
# self.pos_embed = self.create_parameter(
|
||||
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
|
||||
|
||||
# self.add_parameter("pos_embed", self.pos_embed)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
Block_unit = eval(block_unit)
|
||||
|
||||
dpr = np.linspace(0, drop_path_rate, sum(depth))
|
||||
self.blocks1 = nn.ModuleList(
|
||||
[
|
||||
Block_unit(
|
||||
dim=embed_dim[0],
|
||||
num_heads=num_heads[0],
|
||||
mixer=mixer[0:depth[0]][i],
|
||||
HW=self.HW,
|
||||
local_mixer=local_mixer[0],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[0:depth[0]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm) for i in range(depth[0])
|
||||
]
|
||||
)
|
||||
if patch_merging is not None:
|
||||
self.sub_sample1 = SubSample(
|
||||
embed_dim[0],
|
||||
embed_dim[1],
|
||||
sub_norm=sub_norm,
|
||||
stride=[2, 1],
|
||||
types=patch_merging)
|
||||
HW = [self.HW[0] // 2, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.patch_merging = patch_merging
|
||||
self.blocks2 = nn.ModuleList([
|
||||
Block_unit(
|
||||
dim=embed_dim[1],
|
||||
num_heads=num_heads[1],
|
||||
mixer=mixer[depth[0]:depth[0] + depth[1]][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[1],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm) for i in range(depth[1])
|
||||
])
|
||||
if patch_merging is not None:
|
||||
self.sub_sample2 = SubSample(
|
||||
embed_dim[1],
|
||||
embed_dim[2],
|
||||
sub_norm=sub_norm,
|
||||
stride=[2, 1],
|
||||
types=patch_merging)
|
||||
HW = [self.HW[0] // 4, self.HW[1]]
|
||||
else:
|
||||
HW = self.HW
|
||||
self.blocks3 = nn.ModuleList([
|
||||
Block_unit(
|
||||
dim=embed_dim[2],
|
||||
num_heads=num_heads[2],
|
||||
mixer=mixer[depth[0] + depth[1]:][i],
|
||||
HW=HW,
|
||||
local_mixer=local_mixer[2],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=eval(act),
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[depth[0] + depth[1]:][i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=epsilon,
|
||||
prenorm=prenorm) for i in range(depth[2])
|
||||
])
|
||||
self.last_stage = last_stage
|
||||
if last_stage:
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
|
||||
self.last_conv = nn.Conv2d(
|
||||
in_channels=embed_dim[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False)
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = nn.Dropout(p=last_drop)
|
||||
if not prenorm:
|
||||
self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
|
||||
self.use_lenhead = use_lenhead
|
||||
if use_lenhead:
|
||||
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
|
||||
self.hardswish_len = nn.Hardswish()
|
||||
self.dropout_len = nn.Dropout(
|
||||
p=last_drop)
|
||||
|
||||
trunc_normal_(self.pos_embed,std=.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight,std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
for blk in self.blocks1:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample1(
|
||||
x.permute([0, 2, 1]).reshape(
|
||||
[-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
|
||||
for blk in self.blocks2:
|
||||
x = blk(x)
|
||||
if self.patch_merging is not None:
|
||||
x = self.sub_sample2(
|
||||
x.permute([0, 2, 1]).reshape(
|
||||
[-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
|
||||
for blk in self.blocks3:
|
||||
x = blk(x)
|
||||
if not self.prenorm:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if self.use_lenhead:
|
||||
len_x = self.len_conv(x.mean(1))
|
||||
len_x = self.dropout_len(self.hardswish_len(len_x))
|
||||
if self.last_stage:
|
||||
if self.patch_merging is not None:
|
||||
h = self.HW[0] // 4
|
||||
else:
|
||||
h = self.HW[0]
|
||||
x = self.avg_pool(
|
||||
x.permute([0, 2, 1]).reshape(
|
||||
[-1, self.embed_dim[2], h, self.HW[1]]))
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
if self.use_lenhead:
|
||||
return x, len_x
|
||||
return x
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
a = torch.rand(1,3,48,100)
|
||||
svtr = SVTRNet()
|
||||
|
||||
out = svtr(a)
|
||||
print(svtr)
|
||||
print(out.size())
|
@ -1,74 +0,0 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Hswish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hswish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x * F.relu6(x + 3., inplace=self.inplace) / 6.
|
||||
|
||||
# out = max(0, min(1, slop*x+offset))
|
||||
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
|
||||
class Hsigmoid(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Hsigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
|
||||
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
|
||||
return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(GELU, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(Swish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
if self.inplace:
|
||||
x.mul_(torch.sigmoid(x))
|
||||
return x
|
||||
else:
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
class Activation(nn.Module):
|
||||
def __init__(self, act_type, inplace=True):
|
||||
super(Activation, self).__init__()
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
self.act = nn.ReLU(inplace=inplace)
|
||||
elif act_type == 'relu6':
|
||||
self.act = nn.ReLU6(inplace=inplace)
|
||||
elif act_type == 'sigmoid':
|
||||
raise NotImplementedError
|
||||
elif act_type == 'hard_sigmoid':
|
||||
self.act = Hsigmoid(inplace)
|
||||
elif act_type == 'hard_swish':
|
||||
self.act = Hswish(inplace=inplace)
|
||||
elif act_type == 'leakyrelu':
|
||||
self.act = nn.LeakyReLU(inplace=inplace)
|
||||
elif act_type == 'gelu':
|
||||
self.act = GELU(inplace=inplace)
|
||||
elif act_type == 'swish':
|
||||
self.act = Swish(inplace=inplace)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.act(inputs)
|
@ -1,95 +0,0 @@
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
[
|
||||
\
|
||||
]
|
||||
^
|
||||
_
|
||||
`
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
{
|
||||
|
|
||||
}
|
||||
~
|
||||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,151 +0,0 @@
|
||||
import os
|
||||
import datetime
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def save_images(img_list, folder):
|
||||
if not os.path.exists(folder):
|
||||
os.makedirs(folder)
|
||||
now = datetime.datetime.now()
|
||||
date_str = now.strftime("%Y-%m-%d")
|
||||
folder_path = os.path.join(folder, date_str)
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path)
|
||||
time_str = now.strftime("%H_%M_%S")
|
||||
for idx, img in enumerate(img_list):
|
||||
image_number = idx + 1
|
||||
filename = f"{time_str}_{image_number}.jpg"
|
||||
save_path = os.path.join(folder_path, filename)
|
||||
cv2.imwrite(save_path, img[..., ::-1])
|
||||
|
||||
|
||||
def check_channels(image):
|
||||
channels = image.shape[2] if len(image.shape) == 3 else 1
|
||||
if channels == 1:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||
elif channels > 3:
|
||||
image = image[:, :, :3]
|
||||
return image
|
||||
|
||||
|
||||
def resize_image(img, max_length=768):
|
||||
height, width = img.shape[:2]
|
||||
max_dimension = max(height, width)
|
||||
|
||||
if max_dimension > max_length:
|
||||
scale_factor = max_length / max_dimension
|
||||
new_width = int(round(width * scale_factor))
|
||||
new_height = int(round(height * scale_factor))
|
||||
new_size = (new_width, new_height)
|
||||
img = cv2.resize(img, new_size)
|
||||
height, width = img.shape[:2]
|
||||
img = cv2.resize(img, (width - (width % 64), height - (height % 64)))
|
||||
return img
|
||||
|
||||
|
||||
def insert_spaces(string, nSpace):
|
||||
if nSpace == 0:
|
||||
return string
|
||||
new_string = ""
|
||||
for char in string:
|
||||
new_string += char + " " * nSpace
|
||||
return new_string[:-nSpace]
|
||||
|
||||
|
||||
def draw_glyph(font, text):
|
||||
g_size = 50
|
||||
W, H = (512, 80)
|
||||
new_font = font.font_variant(size=g_size)
|
||||
img = Image.new(mode="1", size=(W, H), color=0)
|
||||
draw = ImageDraw.Draw(img)
|
||||
left, top, right, bottom = new_font.getbbox(text)
|
||||
text_width = max(right - left, 5)
|
||||
text_height = max(bottom - top, 5)
|
||||
ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
|
||||
new_font = font.font_variant(size=int(g_size * ratio))
|
||||
|
||||
text_width, text_height = new_font.getsize(text)
|
||||
offset_x, offset_y = new_font.getoffset(text)
|
||||
x = (img.width - text_width) // 2
|
||||
y = (img.height - text_height) // 2 - offset_y // 2
|
||||
draw.text((x, y), text, font=new_font, fill="white")
|
||||
img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
|
||||
return img
|
||||
|
||||
|
||||
def draw_glyph2(
|
||||
font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True
|
||||
):
|
||||
enlarge_polygon = polygon * scale
|
||||
rect = cv2.minAreaRect(enlarge_polygon)
|
||||
box = cv2.boxPoints(rect)
|
||||
box = np.int0(box)
|
||||
w, h = rect[1]
|
||||
angle = rect[2]
|
||||
if angle < -45:
|
||||
angle += 90
|
||||
angle = -angle
|
||||
if w < h:
|
||||
angle += 90
|
||||
|
||||
vert = False
|
||||
if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
|
||||
_w = max(box[:, 0]) - min(box[:, 0])
|
||||
_h = max(box[:, 1]) - min(box[:, 1])
|
||||
if _h >= _w:
|
||||
vert = True
|
||||
angle = 0
|
||||
|
||||
img = np.zeros((height * scale, width * scale, 3), np.uint8)
|
||||
img = Image.fromarray(img)
|
||||
|
||||
# infer font size
|
||||
image4ratio = Image.new("RGB", img.size, "white")
|
||||
draw = ImageDraw.Draw(image4ratio)
|
||||
_, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
|
||||
text_w = min(w, h) * (_tw / _th)
|
||||
if text_w <= max(w, h):
|
||||
# add space
|
||||
if len(text) > 1 and not vert and add_space:
|
||||
for i in range(1, 100):
|
||||
text_space = insert_spaces(text, i)
|
||||
_, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
|
||||
if min(w, h) * (_tw2 / _th2) > max(w, h):
|
||||
break
|
||||
text = insert_spaces(text, i - 1)
|
||||
font_size = min(w, h) * 0.80
|
||||
else:
|
||||
shrink = 0.75 if vert else 0.85
|
||||
font_size = min(w, h) / (text_w / max(w, h)) * shrink
|
||||
new_font = font.font_variant(size=int(font_size))
|
||||
|
||||
left, top, right, bottom = new_font.getbbox(text)
|
||||
text_width = right - left
|
||||
text_height = bottom - top
|
||||
|
||||
layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(layer)
|
||||
if not vert:
|
||||
draw.text(
|
||||
(rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
|
||||
text,
|
||||
font=new_font,
|
||||
fill=(255, 255, 255, 255),
|
||||
)
|
||||
else:
|
||||
x_s = min(box[:, 0]) + _w // 2 - text_height // 2
|
||||
y_s = min(box[:, 1])
|
||||
for c in text:
|
||||
draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
|
||||
_, _t, _, _b = new_font.getbbox(c)
|
||||
y_s += _b
|
||||
|
||||
rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
|
||||
|
||||
x_offset = int((img.width - rotated_layer.width) / 2)
|
||||
y_offset = int((img.height - rotated_layer.height) / 2)
|
||||
img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
|
||||
img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
|
||||
return img
|
@ -1,405 +0,0 @@
|
||||
import abc
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.helper import (
|
||||
boxes_from_mask,
|
||||
resize_max_size,
|
||||
pad_img_to_modulo,
|
||||
switch_mps_device,
|
||||
)
|
||||
from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
|
||||
from .helper.g_diffuser_bot import expand_image
|
||||
from .utils import get_scheduler
|
||||
|
||||
|
||||
class InpaintModel:
|
||||
name = "base"
|
||||
min_size: Optional[int] = None
|
||||
pad_mod = 8
|
||||
pad_to_square = False
|
||||
is_erase_model = False
|
||||
|
||||
def __init__(self, device, **kwargs):
|
||||
"""
|
||||
|
||||
Args:
|
||||
device:
|
||||
"""
|
||||
device = switch_mps_device(self.name, device)
|
||||
self.device = device
|
||||
self.init_model(device, **kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def init_model(self, device, **kwargs): ...
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def is_downloaded() -> bool:
|
||||
return False
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input images and output images have same size
|
||||
images: [H, W, C] RGB
|
||||
masks: [H, W, 1] 255 为 masks 区域
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def download(): ...
|
||||
|
||||
def _pad_forward(self, image, mask, config: InpaintRequest):
|
||||
origin_height, origin_width = image.shape[:2]
|
||||
pad_image = pad_img_to_modulo(
|
||||
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
||||
)
|
||||
pad_mask = pad_img_to_modulo(
|
||||
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
||||
)
|
||||
|
||||
# logger.info(f"final forward pad size: {pad_image.shape}")
|
||||
|
||||
image, mask = self.forward_pre_process(image, mask, config)
|
||||
|
||||
result = self.forward(pad_image, pad_mask, config)
|
||||
result = result[0:origin_height, 0:origin_width, :]
|
||||
|
||||
result, image, mask = self.forward_post_process(result, image, mask, config)
|
||||
|
||||
if config.sd_keep_unmasked_area:
|
||||
mask = mask[:, :, np.newaxis]
|
||||
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
||||
return result
|
||||
|
||||
def forward_pre_process(self, image, mask, config):
|
||||
return image, mask
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
return result, image, mask
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: InpaintRequest):
|
||||
"""
|
||||
images: [H, W, C] RGB, not normalized
|
||||
masks: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
inpaint_result = None
|
||||
# logger.info(f"hd_strategy: {config.hd_strategy}")
|
||||
if config.hd_strategy == HDStrategy.CROP:
|
||||
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
||||
logger.info("Run crop strategy")
|
||||
boxes = boxes_from_mask(mask)
|
||||
crop_result = []
|
||||
for box in boxes:
|
||||
crop_image, crop_box = self._run_box(image, mask, box, config)
|
||||
crop_result.append((crop_image, crop_box))
|
||||
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
for crop_image, crop_box in crop_result:
|
||||
x1, y1, x2, y2 = crop_box
|
||||
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
||||
|
||||
elif config.hd_strategy == HDStrategy.RESIZE:
|
||||
if max(image.shape) > config.hd_strategy_resize_limit:
|
||||
origin_size = image.shape[:2]
|
||||
downsize_image = resize_max_size(
|
||||
image, size_limit=config.hd_strategy_resize_limit
|
||||
)
|
||||
downsize_mask = resize_max_size(
|
||||
mask, size_limit=config.hd_strategy_resize_limit
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
|
||||
)
|
||||
inpaint_result = self._pad_forward(
|
||||
downsize_image, downsize_mask, config
|
||||
)
|
||||
|
||||
# only paste masked area result
|
||||
inpaint_result = cv2.resize(
|
||||
inpaint_result,
|
||||
(origin_size[1], origin_size[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
original_pixel_indices = mask < 127
|
||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
||||
original_pixel_indices
|
||||
]
|
||||
|
||||
if inpaint_result is None:
|
||||
inpaint_result = self._pad_forward(image, mask, config)
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def _crop_box(self, image, mask, box, config: InpaintRequest):
|
||||
"""
|
||||
|
||||
Args:
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1]
|
||||
box: [left,top,right,bottom]
|
||||
|
||||
Returns:
|
||||
BGR IMAGE, (l, r, r, b)
|
||||
"""
|
||||
box_h = box[3] - box[1]
|
||||
box_w = box[2] - box[0]
|
||||
cx = (box[0] + box[2]) // 2
|
||||
cy = (box[1] + box[3]) // 2
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
w = box_w + config.hd_strategy_crop_margin * 2
|
||||
h = box_h + config.hd_strategy_crop_margin * 2
|
||||
|
||||
_l = cx - w // 2
|
||||
_r = cx + w // 2
|
||||
_t = cy - h // 2
|
||||
_b = cy + h // 2
|
||||
|
||||
l = max(_l, 0)
|
||||
r = min(_r, img_w)
|
||||
t = max(_t, 0)
|
||||
b = min(_b, img_h)
|
||||
|
||||
# try to get more context when crop around image edge
|
||||
if _l < 0:
|
||||
r += abs(_l)
|
||||
if _r > img_w:
|
||||
l -= _r - img_w
|
||||
if _t < 0:
|
||||
b += abs(_t)
|
||||
if _b > img_h:
|
||||
t -= _b - img_h
|
||||
|
||||
l = max(l, 0)
|
||||
r = min(r, img_w)
|
||||
t = max(t, 0)
|
||||
b = min(b, img_h)
|
||||
|
||||
crop_img = image[t:b, l:r, :]
|
||||
crop_mask = mask[t:b, l:r]
|
||||
|
||||
# logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
||||
|
||||
return crop_img, crop_mask, [l, t, r, b]
|
||||
|
||||
def _calculate_cdf(self, histogram):
|
||||
cdf = histogram.cumsum()
|
||||
normalized_cdf = cdf / float(cdf.max())
|
||||
return normalized_cdf
|
||||
|
||||
def _calculate_lookup(self, source_cdf, reference_cdf):
|
||||
lookup_table = np.zeros(256)
|
||||
lookup_val = 0
|
||||
for source_index, source_val in enumerate(source_cdf):
|
||||
for reference_index, reference_val in enumerate(reference_cdf):
|
||||
if reference_val >= source_val:
|
||||
lookup_val = reference_index
|
||||
break
|
||||
lookup_table[source_index] = lookup_val
|
||||
return lookup_table
|
||||
|
||||
def _match_histograms(self, source, reference, mask):
|
||||
transformed_channels = []
|
||||
if len(mask.shape) == 3:
|
||||
mask = mask[:, :, -1]
|
||||
|
||||
for channel in range(source.shape[-1]):
|
||||
source_channel = source[:, :, channel]
|
||||
reference_channel = reference[:, :, channel]
|
||||
|
||||
# only calculate histograms for non-masked parts
|
||||
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
|
||||
reference_histogram, _ = np.histogram(
|
||||
reference_channel[mask == 0], 256, [0, 256]
|
||||
)
|
||||
|
||||
source_cdf = self._calculate_cdf(source_histogram)
|
||||
reference_cdf = self._calculate_cdf(reference_histogram)
|
||||
|
||||
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
||||
|
||||
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
||||
|
||||
result = cv2.merge(transformed_channels)
|
||||
result = cv2.convertScaleAbs(result)
|
||||
|
||||
return result
|
||||
|
||||
def _apply_cropper(self, image, mask, config: InpaintRequest):
|
||||
img_h, img_w = image.shape[:2]
|
||||
l, t, w, h = (
|
||||
config.croper_x,
|
||||
config.croper_y,
|
||||
config.croper_width,
|
||||
config.croper_height,
|
||||
)
|
||||
r = l + w
|
||||
b = t + h
|
||||
|
||||
l = max(l, 0)
|
||||
r = min(r, img_w)
|
||||
t = max(t, 0)
|
||||
b = min(b, img_h)
|
||||
|
||||
crop_img = image[t:b, l:r, :]
|
||||
crop_mask = mask[t:b, l:r]
|
||||
return crop_img, crop_mask, (l, t, r, b)
|
||||
|
||||
def _run_box(self, image, mask, box, config: InpaintRequest):
|
||||
"""
|
||||
|
||||
Args:
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1]
|
||||
box: [left,top,right,bottom]
|
||||
|
||||
Returns:
|
||||
BGR IMAGE
|
||||
"""
|
||||
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
||||
|
||||
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
||||
|
||||
|
||||
class DiffusionInpaintModel(InpaintModel):
|
||||
def __init__(self, device, **kwargs):
|
||||
self.model_info = kwargs["model_info"]
|
||||
self.model_id_or_path = self.model_info.path
|
||||
super().__init__(device, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: InpaintRequest):
|
||||
"""
|
||||
images: [H, W, C] RGB, not normalized
|
||||
masks: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
# boxes = boxes_from_mask(mask)
|
||||
if config.use_croper:
|
||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
inpaint_result[t:b, l:r, :] = crop_image
|
||||
elif config.use_extender:
|
||||
inpaint_result = self._do_outpainting(image, config)
|
||||
else:
|
||||
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def _do_outpainting(self, image, config: InpaintRequest):
|
||||
# cropper 和 image 在同一个坐标系下,croper_x/y 可能为负数
|
||||
# 从 image 中 crop 出 outpainting 区域
|
||||
image_h, image_w = image.shape[:2]
|
||||
cropper_l = config.extender_x
|
||||
cropper_t = config.extender_y
|
||||
cropper_r = config.extender_x + config.extender_width
|
||||
cropper_b = config.extender_y + config.extender_height
|
||||
image_l = 0
|
||||
image_t = 0
|
||||
image_r = image_w
|
||||
image_b = image_h
|
||||
|
||||
# 类似求 IOU
|
||||
l = max(cropper_l, image_l)
|
||||
t = max(cropper_t, image_t)
|
||||
r = min(cropper_r, image_r)
|
||||
b = min(cropper_b, image_b)
|
||||
|
||||
assert (
|
||||
0 <= l < r and 0 <= t < b
|
||||
), f"cropper and image not overlap, {l},{t},{r},{b}"
|
||||
|
||||
cropped_image = image[t:b, l:r, :]
|
||||
padding_l = max(0, image_l - cropper_l)
|
||||
padding_t = max(0, image_t - cropper_t)
|
||||
padding_r = max(0, cropper_r - image_r)
|
||||
padding_b = max(0, cropper_b - image_b)
|
||||
|
||||
expanded_image, mask_image = expand_image(
|
||||
cropped_image,
|
||||
left=padding_l,
|
||||
top=padding_t,
|
||||
right=padding_r,
|
||||
bottom=padding_b,
|
||||
)
|
||||
|
||||
# 最终扩大了的 image, BGR
|
||||
expanded_cropped_result_image = self._scaled_pad_forward(
|
||||
expanded_image, mask_image, config
|
||||
)
|
||||
|
||||
# RGB -> BGR
|
||||
outpainting_image = cv2.copyMakeBorder(
|
||||
image,
|
||||
left=padding_l,
|
||||
top=padding_t,
|
||||
right=padding_r,
|
||||
bottom=padding_b,
|
||||
borderType=cv2.BORDER_CONSTANT,
|
||||
value=0,
|
||||
)[:, :, ::-1]
|
||||
|
||||
# 把 cropped_result_image 贴到 outpainting_image 上,这一步不需要 blend
|
||||
paste_t = 0 if config.extender_y < 0 else config.extender_y
|
||||
paste_l = 0 if config.extender_x < 0 else config.extender_x
|
||||
|
||||
outpainting_image[
|
||||
paste_t : paste_t + expanded_cropped_result_image.shape[0],
|
||||
paste_l : paste_l + expanded_cropped_result_image.shape[1],
|
||||
:,
|
||||
] = expanded_cropped_result_image
|
||||
return outpainting_image
|
||||
|
||||
def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
|
||||
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||
origin_size = image.shape[:2]
|
||||
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
||||
if config.sd_scale != 1:
|
||||
logger.info(
|
||||
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
||||
)
|
||||
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||
# only paste masked area result
|
||||
inpaint_result = cv2.resize(
|
||||
inpaint_result,
|
||||
(origin_size[1], origin_size[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def set_scheduler(self, config: InpaintRequest):
|
||||
scheduler_config = self.model.scheduler.config
|
||||
sd_sampler = config.sd_sampler
|
||||
if config.sd_lcm_lora and self.model_info.support_lcm_lora:
|
||||
sd_sampler = SDSampler.lcm
|
||||
logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
|
||||
scheduler = get_scheduler(sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
|
||||
def forward_pre_process(self, image, mask, config):
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||
|
||||
return image, mask
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
if config.sd_match_histograms:
|
||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||
|
||||
if config.use_extender and config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||
return result, image, mask
|
@ -1,931 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
from diffusers.models.attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
|
||||
TimestepEmbedding, Timesteps
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D, get_down_block, get_up_block,
|
||||
)
|
||||
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_2d_blocks import MidBlock2D
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrushNetOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`BrushNetModel`].
|
||||
|
||||
Args:
|
||||
up_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's upsampling activations.
|
||||
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's downsampling activations.
|
||||
mid_down_block_re_sample (`torch.Tensor`):
|
||||
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
up_block_res_samples: Tuple[torch.Tensor]
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A BrushNet model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
||||
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
||||
The tuple of upsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 5,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2D",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in_condition = nn.Conv2d(
|
||||
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == "text_proj":
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
elif encoder_hid_dim_type == "text_image_proj":
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.brushnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
|
||||
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_mid_block = brushnet_block
|
||||
|
||||
self.mid_block = MidBlock2D(
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0.0,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.brushnet_up_blocks = nn.ModuleList([])
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
add_upsample = True
|
||||
self.num_upsamplers += 1
|
||||
else:
|
||||
add_upsample = False
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resolution_idx=i,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
for _ in range(layers_per_block + 1):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 5,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
brushnet = cls(
|
||||
in_channels=unet.config.in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D'],
|
||||
mid_block_type='MidBlock2D',
|
||||
up_block_types=['UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'],
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
class_embed_type=unet.config.class_embed_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
num_class_embeds=unet.config.num_class_embeds,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
||||
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
||||
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
||||
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
||||
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
||||
|
||||
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
if brushnet.class_embedding:
|
||||
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||
|
||||
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
||||
|
||||
return brushnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
brushnet_cond: torch.FloatTensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
||||
"""
|
||||
The [`BrushNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
brushnet_cond (`torch.FloatTensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for BrushNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.brushnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
|
||||
sample = self.conv_in_condition(brushnet_cond)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. PaintingNet down blocks
|
||||
brushnet_down_block_res_samples = ()
|
||||
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
||||
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
||||
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
# 5. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 6. BrushNet mid blocks
|
||||
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
||||
|
||||
# 7. up
|
||||
up_block_res_samples = ()
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample, up_res_samples = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
return_res_samples=True
|
||||
)
|
||||
else:
|
||||
sample, up_res_samples = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
return_res_samples=True
|
||||
)
|
||||
|
||||
up_block_res_samples += up_res_samples
|
||||
|
||||
# 8. BrushNet up blocks
|
||||
brushnet_up_block_res_samples = ()
|
||||
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
||||
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
||||
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0,
|
||||
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
||||
device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
|
||||
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
|
||||
scales[:len(
|
||||
brushnet_down_block_res_samples)])]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
||||
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
|
||||
scales[
|
||||
len(brushnet_down_block_res_samples) + 1:])]
|
||||
else:
|
||||
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
|
||||
brushnet_down_block_res_samples]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
||||
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
brushnet_down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
||||
]
|
||||
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
brushnet_up_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
||||
|
||||
return BrushNetOutput(
|
||||
down_block_res_samples=brushnet_down_block_res_samples,
|
||||
mid_block_res_sample=brushnet_mid_block_res_sample,
|
||||
up_block_res_samples=brushnet_up_block_res_samples
|
||||
)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
BrushNetModel.from_pretrained("/Users/cwq/data/models/brushnet/brushnet_random_mask", variant='fp16',
|
||||
use_safetensors=True)
|
@ -1,322 +0,0 @@
|
||||
from typing import Union, Optional, Dict, Any, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers
|
||||
|
||||
|
||||
def brushnet_unet_forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
||||
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DConditionModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
encoder_attention_mask (`torch.Tensor`):
|
||||
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||
example from ControlNet side model(s)
|
||||
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||
|
||||
Returns:
|
||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2 ** self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
for dim in sample.shape[-2:]:
|
||||
if dim % default_overall_up_factor != 0:
|
||||
# Forward upsample size to force interpolation output size.
|
||||
forward_upsample_size = True
|
||||
break
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
||||
if class_emb is not None:
|
||||
if self.config.class_embeddings_concat:
|
||||
emb = torch.cat([emb, class_emb], dim=-1)
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
aug_emb = self.get_aug_embed(
|
||||
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
if self.config.addition_embed_type == "image_hint":
|
||||
aug_emb, hint = aug_emb
|
||||
sample = torch.cat([sample, hint], dim=1)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
|
||||
encoder_hidden_states = self.process_encoder_hidden_states(
|
||||
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 2.5 GLIGEN position net
|
||||
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||
gligen_args = cross_attention_kwargs.pop("gligen")
|
||||
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||
|
||||
# 3. down
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||
is_adapter = down_intrablock_additional_residuals is not None
|
||||
# maintain backward compatibility for legacy usage, where
|
||||
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||
# but can only use one or the other
|
||||
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
|
||||
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
||||
deprecate(
|
||||
"T2I should not use down_block_additional_residuals",
|
||||
"1.3.0",
|
||||
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
||||
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
||||
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
||||
standard_warn=False,
|
||||
)
|
||||
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||
is_adapter = True
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
|
||||
if is_brushnet:
|
||||
sample = sample + down_block_add_samples.pop(0)
|
||||
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
# For t2i-adapter CrossAttnDownBlock2D
|
||||
additional_residuals = {}
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
if is_brushnet and len(down_block_add_samples) > 0:
|
||||
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(down_block_add_samples) > 0:
|
||||
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale,
|
||||
**additional_residuals)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
if is_controlnet:
|
||||
new_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, down_block_additional_residual in zip(
|
||||
down_block_res_samples, down_block_additional_residuals
|
||||
):
|
||||
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
||||
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# To support T2I-Adapter-XL
|
||||
if (
|
||||
is_adapter
|
||||
and len(down_intrablock_additional_residuals) > 0
|
||||
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||
):
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
if is_controlnet:
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
if is_brushnet:
|
||||
sample = sample + mid_block_add_sample
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(up_block_add_samples) > 0:
|
||||
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(up_block_add_samples) > 0:
|
||||
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
scale=lora_scale,
|
||||
**additional_residuals,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
if self.conv_norm_out:
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DConditionOutput(sample=sample)
|
@ -1,157 +0,0 @@
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
import numpy as np
|
||||
|
||||
from ..base import DiffusionInpaintModel
|
||||
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from ..original_sd_configs import get_config_files
|
||||
from ..utils import (
|
||||
handle_from_pretrained_exceptions,
|
||||
get_torch_dtype,
|
||||
enable_low_mem,
|
||||
is_local_files_only,
|
||||
)
|
||||
from .brushnet import BrushNetModel
|
||||
from .brushnet_unet_forward import brushnet_unet_forward
|
||||
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
|
||||
UpBlock2D_forward
|
||||
from ...schema import InpaintRequest, ModelType
|
||||
|
||||
|
||||
class BrushNetWrapper(DiffusionInpaintModel):
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
|
||||
self.model_info = kwargs["model_info"]
|
||||
self.brushnet_method = kwargs["brushnet_method"]
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
model_kwargs = {
|
||||
**kwargs.get("pipe_components", {}),
|
||||
"local_files_only": is_local_files_only(**kwargs),
|
||||
}
|
||||
self.local_files_only = model_kwargs["local_files_only"]
|
||||
|
||||
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||
"cpu_offload", False
|
||||
)
|
||||
if disable_nsfw_checker:
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
|
||||
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
|
||||
|
||||
if self.model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
model_kwargs["num_in_channels"] = 4
|
||||
else:
|
||||
model_kwargs["num_in_channels"] = 9
|
||||
|
||||
self.model = StableDiffusionBrushNetPipeline.from_single_file(
|
||||
self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
load_safety_checker=not disable_nsfw_checker,
|
||||
original_config_file=get_config_files()['v1'],
|
||||
brushnet=brushnet,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
StableDiffusionBrushNetPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
brushnet=brushnet,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs["sd_cpu_textencoder"]:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder, torch_dtype
|
||||
)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
|
||||
|
||||
for down_block in self.model.brushnet.down_blocks:
|
||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
for up_block in self.model.brushnet.up_blocks:
|
||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
|
||||
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||
for down_block in self.model.unet.down_blocks:
|
||||
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
else:
|
||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
|
||||
for up_block in self.model.unet.up_blocks:
|
||||
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
else:
|
||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
|
||||
def switch_brushnet_method(self, new_method: str):
|
||||
self.brushnet_method = new_method
|
||||
brushnet = BrushNetModel.from_pretrained(
|
||||
new_method,
|
||||
resume_download=True,
|
||||
local_files_only=self.local_files_only,
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(self.model.device)
|
||||
self.model.brushnet = brushnet
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
normalized_mask = mask[:, :].astype("float32") / 255.0
|
||||
image = image * (1 - normalized_mask)
|
||||
image = image.astype(np.uint8)
|
||||
output = self.model(
|
||||
image=PIL.Image.fromarray(image),
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
|
||||
num_inference_steps=config.sd_steps,
|
||||
# strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
brushnet_conditioning_scale=config.brushnet_conditioning_scale,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
File diff suppressed because it is too large
Load Diff
@ -1,388 +0,0 @@
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.resnet import ResnetBlock2D
|
||||
from diffusers.utils import is_torch_version
|
||||
from diffusers.utils.torch_utils import apply_freeu
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MidBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor: float = 1.0,
|
||||
use_linear_projection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = False
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
|
||||
for i in range(num_layers):
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
lora_scale = 1.0
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
for resnet in self.resnets[1:]:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def DownBlock2D_forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0,
|
||||
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def CrossAttnDownBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||
hidden_states = hidden_states + additional_residuals
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def CrossAttnUpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
return_res_samples: Optional[bool] = False,
|
||||
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
if return_res_samples:
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||
|
||||
if return_res_samples:
|
||||
return hidden_states, output_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def UpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
return_res_samples: Optional[bool] = False,
|
||||
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
if return_res_samples:
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
if return_res_samples:
|
||||
return hidden_states, output_states
|
||||
else:
|
||||
return hidden_states
|
@ -1,194 +0,0 @@
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
from loguru import logger
|
||||
from iopaint.schema import InpaintRequest, ModelType
|
||||
|
||||
from .base import DiffusionInpaintModel
|
||||
from .helper.controlnet_preprocess import (
|
||||
make_canny_control_image,
|
||||
make_openpose_control_image,
|
||||
make_depth_control_image,
|
||||
make_inpaint_control_image,
|
||||
)
|
||||
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from .original_sd_configs import get_config_files
|
||||
from .utils import (
|
||||
get_scheduler,
|
||||
handle_from_pretrained_exceptions,
|
||||
get_torch_dtype,
|
||||
enable_low_mem,
|
||||
is_local_files_only,
|
||||
)
|
||||
|
||||
|
||||
class ControlNet(DiffusionInpaintModel):
|
||||
name = "controlnet"
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
@property
|
||||
def lcm_lora_id(self):
|
||||
if self.model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
]:
|
||||
return "latent-consistency/lcm-lora-sdv1-5"
|
||||
if self.model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]:
|
||||
return "latent-consistency/lcm-lora-sdxl"
|
||||
raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
model_info = kwargs["model_info"]
|
||||
controlnet_method = kwargs["controlnet_method"]
|
||||
|
||||
self.model_info = model_info
|
||||
self.controlnet_method = controlnet_method
|
||||
|
||||
model_kwargs = {
|
||||
**kwargs.get("pipe_components", {}),
|
||||
"local_files_only": is_local_files_only(**kwargs),
|
||||
}
|
||||
self.local_files_only = model_kwargs["local_files_only"]
|
||||
|
||||
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||
"cpu_offload", False
|
||||
)
|
||||
if disable_nsfw_checker:
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
original_config_file_name = "v1"
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
]:
|
||||
from diffusers import (
|
||||
StableDiffusionControlNetInpaintPipeline as PipeClass,
|
||||
)
|
||||
original_config_file_name = "v1"
|
||||
|
||||
elif model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]:
|
||||
from diffusers import (
|
||||
StableDiffusionXLControlNetInpaintPipeline as PipeClass,
|
||||
)
|
||||
original_config_file_name = "xl"
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
pretrained_model_name_or_path=controlnet_method,
|
||||
resume_download=True,
|
||||
local_files_only=model_kwargs["local_files_only"],
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
if model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
model_kwargs["num_in_channels"] = 4
|
||||
else:
|
||||
model_kwargs["num_in_channels"] = 9
|
||||
|
||||
self.model = PipeClass.from_single_file(
|
||||
model_info.path,
|
||||
controlnet=controlnet,
|
||||
load_safety_checker=not disable_nsfw_checker,
|
||||
torch_dtype=torch_dtype,
|
||||
original_config_file=get_config_files()[original_config_file_name],
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
PipeClass.from_pretrained,
|
||||
pretrained_model_name_or_path=model_info.path,
|
||||
controlnet=controlnet,
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs["sd_cpu_textencoder"]:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder, torch_dtype
|
||||
)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def switch_controlnet_method(self, new_method: str):
|
||||
self.controlnet_method = new_method
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
new_method,
|
||||
resume_download=True,
|
||||
local_files_only=self.local_files_only,
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(self.model.device)
|
||||
self.model.controlnet = controlnet
|
||||
|
||||
def _get_control_image(self, image, mask):
|
||||
if "canny" in self.controlnet_method:
|
||||
control_image = make_canny_control_image(image)
|
||||
elif "openpose" in self.controlnet_method:
|
||||
control_image = make_openpose_control_image(image)
|
||||
elif "depth" in self.controlnet_method:
|
||||
control_image = make_depth_control_image(image)
|
||||
elif "inpaint" in self.controlnet_method:
|
||||
control_image = make_inpaint_control_image(image, mask)
|
||||
else:
|
||||
raise NotImplementedError(f"{self.controlnet_method} not implemented")
|
||||
return control_image
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
scheduler_config = self.model.scheduler.config
|
||||
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
control_image = self._get_control_image(image, mask)
|
||||
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
|
||||
image = PIL.Image.fromarray(image)
|
||||
|
||||
output = self.model(
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
control_image=control_image,
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
controlnet_conditioning_scale=config.controlnet_conditioning_scale,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
@ -1,193 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from .utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear"):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(
|
||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||
):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
# array([1])
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(self.model.betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, steps, conditioning, batch_size, shape):
|
||||
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
|
||||
# samples: 1,3,128,128
|
||||
return self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
quantize_denoised=False,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=0,
|
||||
temperature=1.0,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
ddim_use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
)
|
||||
|
||||
time_range = (
|
||||
reversed(range(0, timesteps))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
)
|
||||
img, _ = outs
|
||||
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised: # 没用
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.0: # 没用
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
1737
iopaint/model/fcf.py
1737
iopaint/model/fcf.py
File diff suppressed because it is too large
Load Diff
@ -1,68 +0,0 @@
|
||||
import torch
|
||||
import PIL
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from iopaint.helper import pad_img_to_modulo
|
||||
|
||||
|
||||
def make_canny_control_image(image: np.ndarray) -> Image:
|
||||
canny_image = cv2.Canny(image, 100, 200)
|
||||
canny_image = canny_image[:, :, None]
|
||||
canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
|
||||
canny_image = PIL.Image.fromarray(canny_image)
|
||||
control_image = canny_image
|
||||
return control_image
|
||||
|
||||
|
||||
def make_openpose_control_image(image: np.ndarray) -> Image:
|
||||
from controlnet_aux import OpenposeDetector
|
||||
|
||||
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
||||
control_image = processor(image, hand_and_face=True)
|
||||
return control_image
|
||||
|
||||
|
||||
def resize_image(input_image, resolution):
|
||||
H, W, C = input_image.shape
|
||||
H = float(H)
|
||||
W = float(W)
|
||||
k = float(resolution) / min(H, W)
|
||||
H *= k
|
||||
W *= k
|
||||
H = int(np.round(H / 64.0)) * 64
|
||||
W = int(np.round(W / 64.0)) * 64
|
||||
img = cv2.resize(
|
||||
input_image,
|
||||
(W, H),
|
||||
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
|
||||
)
|
||||
return img
|
||||
|
||||
|
||||
def make_depth_control_image(image: np.ndarray) -> Image:
|
||||
from controlnet_aux import MidasDetector
|
||||
|
||||
midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
origin_height, origin_width = image.shape[:2]
|
||||
pad_image = pad_img_to_modulo(image, mod=64, square=False, min_size=512)
|
||||
depth_image = midas(pad_image)
|
||||
depth_image = depth_image[0:origin_height, 0:origin_width]
|
||||
depth_image = depth_image[:, :, None]
|
||||
depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2)
|
||||
control_image = PIL.Image.fromarray(depth_image)
|
||||
return control_image
|
||||
|
||||
|
||||
def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
"""
|
||||
image = image.astype(np.float32) / 255.0
|
||||
image[mask[:, :, -1] > 128] = -1.0 # set as masked pixel
|
||||
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return image
|
@ -1,41 +0,0 @@
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..utils import torch_gc
|
||||
|
||||
|
||||
class CPUTextEncoderWrapper(PreTrainedModel):
|
||||
def __init__(self, text_encoder, torch_dtype):
|
||||
super().__init__(text_encoder.config)
|
||||
self.config = text_encoder.config
|
||||
self._device = text_encoder.device
|
||||
# cpu not support float16
|
||||
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||
self.torch_dtype = torch_dtype
|
||||
del text_encoder
|
||||
torch_gc()
|
||||
|
||||
def __call__(self, x, **kwargs):
|
||||
input_device = x.device
|
||||
original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
|
||||
for k, v in original_output.items():
|
||||
if isinstance(v, tuple):
|
||||
original_output[k] = [
|
||||
v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
|
||||
]
|
||||
else:
|
||||
original_output[k] = v.to(input_device).to(self.torch_dtype)
|
||||
return original_output
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.torch_dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
"""
|
||||
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
||||
device).
|
||||
"""
|
||||
return self._device
|
@ -1,62 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def expand_image(cv2_img, top: int, right: int, bottom: int, left: int):
|
||||
assert cv2_img.shape[2] == 3
|
||||
origin_h, origin_w = cv2_img.shape[:2]
|
||||
|
||||
# TODO: which is better?
|
||||
# new_img = np.ones((new_height, new_width, 3), np.uint8) * 255
|
||||
new_img = cv2.copyMakeBorder(
|
||||
cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
|
||||
)
|
||||
|
||||
inner_padding_left = 0 if left > 0 else 0
|
||||
inner_padding_right = 0 if right > 0 else 0
|
||||
inner_padding_top = 0 if top > 0 else 0
|
||||
inner_padding_bottom = 0 if bottom > 0 else 0
|
||||
|
||||
mask_image = np.zeros(
|
||||
(
|
||||
origin_h - inner_padding_top - inner_padding_bottom,
|
||||
origin_w - inner_padding_left - inner_padding_right,
|
||||
),
|
||||
np.uint8,
|
||||
)
|
||||
mask_image = cv2.copyMakeBorder(
|
||||
mask_image,
|
||||
top + inner_padding_top,
|
||||
bottom + inner_padding_bottom,
|
||||
left + inner_padding_left,
|
||||
right + inner_padding_right,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=255,
|
||||
)
|
||||
# k = 2*int(min(origin_h, origin_w) // 6)+1
|
||||
# k = 7
|
||||
# mask_image = cv2.GaussianBlur(mask_image, (k, k), 0)
|
||||
return new_img, mask_image
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pathlib import Path
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
image_path = "/Users/cwq/code/github/IOPaint/iopaint/tests/bunny.jpeg"
|
||||
init_image = cv2.imread(str(image_path))
|
||||
init_image, mask_image = expand_image(
|
||||
init_image,
|
||||
top=0,
|
||||
right=0,
|
||||
bottom=0,
|
||||
left=100,
|
||||
softness=20,
|
||||
space=20,
|
||||
)
|
||||
print(mask_image.dtype, mask_image.min(), mask_image.max())
|
||||
print(init_image.dtype, init_image.min(), init_image.max())
|
||||
mask_image = mask_image.astype(np.uint8)
|
||||
init_image = init_image.astype(np.uint8)
|
||||
cv2.imwrite("expanded_image.png", init_image)
|
||||
cv2.imwrite("expanded_mask.png", mask_image)
|
@ -1,64 +0,0 @@
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.const import INSTRUCT_PIX2PIX_NAME
|
||||
from .base import DiffusionInpaintModel
|
||||
from iopaint.schema import InpaintRequest
|
||||
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
|
||||
|
||||
|
||||
class InstructPix2Pix(DiffusionInpaintModel):
|
||||
name = INSTRUCT_PIX2PIX_NAME
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from diffusers import StableDiffusionInstructPix2PixPipeline
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
|
||||
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
|
||||
)
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
|
||||
"""
|
||||
output = self.model(
|
||||
image=PIL.Image.fromarray(image),
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
num_inference_steps=config.sd_steps,
|
||||
image_guidance_scale=config.p2p_image_guidance_scale,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
@ -1,65 +0,0 @@
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from iopaint.const import KANDINSKY22_NAME
|
||||
from .base import DiffusionInpaintModel
|
||||
from iopaint.schema import InpaintRequest
|
||||
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
|
||||
|
||||
|
||||
class Kandinsky(DiffusionInpaintModel):
|
||||
pad_mod = 64
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
|
||||
model_kwargs = {
|
||||
"torch_dtype": torch_dtype,
|
||||
"local_files_only": is_local_files_only(**kwargs),
|
||||
}
|
||||
self.model = AutoPipelineForInpainting.from_pretrained(
|
||||
self.name, **model_kwargs
|
||||
).to(device)
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
generator = torch.manual_seed(config.sd_seed)
|
||||
mask = mask.astype(np.float32) / 255
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
# kandinsky 没有 strength
|
||||
output = self.model(
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
image=PIL.Image.fromarray(image),
|
||||
mask_image=mask[:, :, 0],
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback_on_step_end=self.callback,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
|
||||
class Kandinsky22(Kandinsky):
|
||||
name = KANDINSKY22_NAME
|
@ -1,57 +0,0 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from iopaint.helper import (
|
||||
norm_img,
|
||||
get_cache_path_by_url,
|
||||
load_jit_model,
|
||||
download_model,
|
||||
)
|
||||
from iopaint.schema import InpaintRequest
|
||||
from .base import InpaintModel
|
||||
|
||||
LAMA_MODEL_URL = os.environ.get(
|
||||
"LAMA_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
)
|
||||
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
|
||||
|
||||
|
||||
class LaMa(InpaintModel):
|
||||
name = "lama"
|
||||
pad_mod = 8
|
||||
is_erase_model = True
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
download_model(LAMA_MODEL_URL, LAMA_MODEL_MD5)
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
image = norm_img(image)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
|
||||
inpainted_image = self.model(image, mask)
|
||||
|
||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
||||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
||||
return cur_res
|
@ -1,336 +0,0 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from .base import InpaintModel
|
||||
from .ddim_sampler import DDIMSampler
|
||||
from .plms_sampler import PLMSSampler
|
||||
from iopaint.schema import InpaintRequest, LDMSampler
|
||||
|
||||
torch.manual_seed(42)
|
||||
import torch.nn as nn
|
||||
from iopaint.helper import (
|
||||
download_model,
|
||||
norm_img,
|
||||
get_cache_path_by_url,
|
||||
load_jit_model,
|
||||
)
|
||||
from .utils import (
|
||||
make_beta_schedule,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
LDM_ENCODE_MODEL_URL = os.environ.get(
|
||||
"LDM_ENCODE_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
|
||||
)
|
||||
LDM_ENCODE_MODEL_MD5 = os.environ.get(
|
||||
"LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
|
||||
)
|
||||
|
||||
LDM_DECODE_MODEL_URL = os.environ.get(
|
||||
"LDM_DECODE_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
|
||||
)
|
||||
LDM_DECODE_MODEL_MD5 = os.environ.get(
|
||||
"LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
|
||||
)
|
||||
|
||||
LDM_DIFFUSION_MODEL_URL = os.environ.get(
|
||||
"LDM_DIFFUSION_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
|
||||
)
|
||||
|
||||
LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
|
||||
"LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
|
||||
)
|
||||
|
||||
|
||||
class DDPM(nn.Module):
|
||||
# classic DDPM with Gaussian diffusion, in image space
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
timesteps=1000,
|
||||
beta_schedule="linear",
|
||||
linear_start=0.0015,
|
||||
linear_end=0.0205,
|
||||
cosine_s=0.008,
|
||||
original_elbo_weight=0.0,
|
||||
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
||||
l_simple_weight=1.0,
|
||||
parameterization="eps", # all assuming fixed variance schedules
|
||||
use_positional_encodings=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.parameterization = parameterization
|
||||
self.use_positional_encodings = use_positional_encodings
|
||||
|
||||
self.v_posterior = v_posterior
|
||||
self.original_elbo_weight = original_elbo_weight
|
||||
self.l_simple_weight = l_simple_weight
|
||||
|
||||
self.register_schedule(
|
||||
beta_schedule=beta_schedule,
|
||||
timesteps=timesteps,
|
||||
linear_start=linear_start,
|
||||
linear_end=linear_end,
|
||||
cosine_s=cosine_s,
|
||||
)
|
||||
|
||||
def register_schedule(
|
||||
self,
|
||||
given_betas=None,
|
||||
beta_schedule="linear",
|
||||
timesteps=1000,
|
||||
linear_start=1e-4,
|
||||
linear_end=2e-2,
|
||||
cosine_s=8e-3,
|
||||
):
|
||||
betas = make_beta_schedule(
|
||||
self.device,
|
||||
beta_schedule,
|
||||
timesteps,
|
||||
linear_start=linear_start,
|
||||
linear_end=linear_end,
|
||||
cosine_s=cosine_s,
|
||||
)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
||||
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
|
||||
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
|
||||
)
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = (1 - self.v_posterior) * betas * (
|
||||
1.0 - alphas_cumprod_prev
|
||||
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.register_buffer(
|
||||
"posterior_log_variance_clipped",
|
||||
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
|
||||
)
|
||||
self.register_buffer(
|
||||
"posterior_mean_coef1",
|
||||
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
|
||||
)
|
||||
self.register_buffer(
|
||||
"posterior_mean_coef2",
|
||||
to_torch(
|
||||
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||
),
|
||||
)
|
||||
|
||||
if self.parameterization == "eps":
|
||||
lvlb_weights = self.betas**2 / (
|
||||
2
|
||||
* self.posterior_variance
|
||||
* to_torch(alphas)
|
||||
* (1 - self.alphas_cumprod)
|
||||
)
|
||||
elif self.parameterization == "x0":
|
||||
lvlb_weights = (
|
||||
0.5
|
||||
* np.sqrt(torch.Tensor(alphas_cumprod))
|
||||
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("mu not supported")
|
||||
# TODO how to choose this term
|
||||
lvlb_weights[0] = lvlb_weights[1]
|
||||
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
|
||||
assert not torch.isnan(self.lvlb_weights).all()
|
||||
|
||||
|
||||
class LatentDiffusion(DDPM):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_model,
|
||||
device,
|
||||
cond_stage_key="image",
|
||||
cond_stage_trainable=False,
|
||||
concat_mode=True,
|
||||
scale_factor=1.0,
|
||||
scale_by_std=False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_timesteps_cond = 1
|
||||
self.scale_by_std = scale_by_std
|
||||
super().__init__(device, *args, **kwargs)
|
||||
self.diffusion_model = diffusion_model
|
||||
self.concat_mode = concat_mode
|
||||
self.cond_stage_trainable = cond_stage_trainable
|
||||
self.cond_stage_key = cond_stage_key
|
||||
self.num_downs = 2
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def make_cond_schedule(
|
||||
self,
|
||||
):
|
||||
self.cond_ids = torch.full(
|
||||
size=(self.num_timesteps,),
|
||||
fill_value=self.num_timesteps - 1,
|
||||
dtype=torch.long,
|
||||
)
|
||||
ids = torch.round(
|
||||
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
|
||||
).long()
|
||||
self.cond_ids[: self.num_timesteps_cond] = ids
|
||||
|
||||
def register_schedule(
|
||||
self,
|
||||
given_betas=None,
|
||||
beta_schedule="linear",
|
||||
timesteps=1000,
|
||||
linear_start=1e-4,
|
||||
linear_end=2e-2,
|
||||
cosine_s=8e-3,
|
||||
):
|
||||
super().register_schedule(
|
||||
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
||||
)
|
||||
|
||||
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
||||
if self.shorten_cond_schedule:
|
||||
self.make_cond_schedule()
|
||||
|
||||
def apply_model(self, x_noisy, t, cond):
|
||||
# x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
|
||||
t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
|
||||
x_recon = self.diffusion_model(x_noisy, t_emb, cond)
|
||||
return x_recon
|
||||
|
||||
|
||||
class LDM(InpaintModel):
|
||||
name = "ldm"
|
||||
pad_mod = 32
|
||||
is_erase_model = True
|
||||
|
||||
def __init__(self, device, fp16: bool = True, **kwargs):
|
||||
self.fp16 = fp16
|
||||
super().__init__(device)
|
||||
self.device = device
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
self.diffusion_model = load_jit_model(
|
||||
LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
|
||||
)
|
||||
self.cond_stage_model_decode = load_jit_model(
|
||||
LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
|
||||
)
|
||||
self.cond_stage_model_encode = load_jit_model(
|
||||
LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
|
||||
)
|
||||
if self.fp16 and "cuda" in str(device):
|
||||
self.diffusion_model = self.diffusion_model.half()
|
||||
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
||||
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
||||
|
||||
self.model = LatentDiffusion(self.diffusion_model, device)
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
download_model(LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5)
|
||||
download_model(LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5)
|
||||
download_model(LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5)
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
model_paths = [
|
||||
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
||||
]
|
||||
return all([os.path.exists(it) for it in model_paths])
|
||||
|
||||
@torch.cuda.amp.autocast()
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
# image [1,3,512,512] float32
|
||||
# mask: [1,1,512,512] float32
|
||||
# masked_image: [1,3,512,512] float32
|
||||
if config.ldm_sampler == LDMSampler.ddim:
|
||||
sampler = DDIMSampler(self.model)
|
||||
elif config.ldm_sampler == LDMSampler.plms:
|
||||
sampler = PLMSSampler(self.model)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
steps = config.ldm_steps
|
||||
image = norm_img(image)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
masked_image = (1 - mask) * image
|
||||
|
||||
mask = self._norm(mask)
|
||||
masked_image = self._norm(masked_image)
|
||||
|
||||
c = self.cond_stage_model_encode(masked_image)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
|
||||
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
||||
|
||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
||||
samples_ddim = sampler.sample(
|
||||
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
x_samples_ddim = self.cond_stage_model_decode(
|
||||
samples_ddim
|
||||
) # samples_ddim: 1, 3, 128, 128 float32
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
# inpainted = (1 - mask) * image + mask * predicted_image
|
||||
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
||||
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
|
||||
return inpainted_image
|
||||
|
||||
def _norm(self, tensor):
|
||||
return tensor * 2.0 - 1.0
|
@ -1,97 +0,0 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import time
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model
|
||||
from .base import InpaintModel
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
|
||||
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
|
||||
"MANGA_INPAINTOR_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit",
|
||||
)
|
||||
MANGA_INPAINTOR_MODEL_MD5 = os.environ.get(
|
||||
"MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c"
|
||||
)
|
||||
|
||||
MANGA_LINE_MODEL_URL = os.environ.get(
|
||||
"MANGA_LINE_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/manga/erika.jit",
|
||||
)
|
||||
MANGA_LINE_MODEL_MD5 = os.environ.get(
|
||||
"MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644"
|
||||
)
|
||||
|
||||
|
||||
class Manga(InpaintModel):
|
||||
name = "manga"
|
||||
pad_mod = 16
|
||||
is_erase_model = True
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
self.inpaintor_model = load_jit_model(
|
||||
MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5
|
||||
)
|
||||
self.line_model = load_jit_model(
|
||||
MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5
|
||||
)
|
||||
self.seed = 42
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
download_model(MANGA_INPAINTOR_MODEL_URL, MANGA_INPAINTOR_MODEL_MD5)
|
||||
download_model(MANGA_LINE_MODEL_URL, MANGA_LINE_MODEL_MD5)
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
model_paths = [
|
||||
get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
|
||||
get_cache_path_by_url(MANGA_LINE_MODEL_URL),
|
||||
]
|
||||
return all([os.path.exists(it) for it in model_paths])
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
seed = self.seed
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
gray_img = torch.from_numpy(
|
||||
gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)
|
||||
).to(self.device)
|
||||
start = time.time()
|
||||
lines = self.line_model(gray_img)
|
||||
torch.cuda.empty_cache()
|
||||
lines = torch.clamp(lines, 0, 255)
|
||||
logger.info(f"erika_model time: {time.time() - start}")
|
||||
|
||||
mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
|
||||
mask = mask.permute(0, 3, 1, 2)
|
||||
mask = torch.where(mask > 0.5, 1.0, 0.0)
|
||||
noise = torch.randn_like(mask)
|
||||
ones = torch.ones_like(mask)
|
||||
|
||||
gray_img = gray_img / 255 * 2 - 1.0
|
||||
lines = lines / 255 * 2 - 1.0
|
||||
|
||||
start = time.time()
|
||||
inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
|
||||
logger.info(f"image_inpaintor_model time: {time.time() - start}")
|
||||
|
||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
|
||||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
|
||||
return cur_res
|
1945
iopaint/model/mat.py
1945
iopaint/model/mat.py
File diff suppressed because it is too large
Load Diff
@ -1,110 +0,0 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from iopaint.helper import (
|
||||
load_jit_model,
|
||||
download_model,
|
||||
get_cache_path_by_url,
|
||||
boxes_from_mask,
|
||||
resize_max_size,
|
||||
norm_img,
|
||||
)
|
||||
from .base import InpaintModel
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
MIGAN_MODEL_URL = os.environ.get(
|
||||
"MIGAN_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
|
||||
)
|
||||
MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
|
||||
|
||||
|
||||
class MIGAN(InpaintModel):
|
||||
name = "migan"
|
||||
min_size = 512
|
||||
pad_mod = 512
|
||||
pad_to_square = True
|
||||
is_erase_model = True
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
download_model(MIGAN_MODEL_URL, MIGAN_MODEL_MD5)
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: InpaintRequest):
|
||||
"""
|
||||
images: [H, W, C] RGB, not normalized
|
||||
masks: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
if image.shape[0] == 512 and image.shape[1] == 512:
|
||||
return self._pad_forward(image, mask, config)
|
||||
|
||||
boxes = boxes_from_mask(mask)
|
||||
crop_result = []
|
||||
config.hd_strategy_crop_margin = 128
|
||||
for box in boxes:
|
||||
crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
|
||||
origin_size = crop_image.shape[:2]
|
||||
resize_image = resize_max_size(crop_image, size_limit=512)
|
||||
resize_mask = resize_max_size(crop_mask, size_limit=512)
|
||||
inpaint_result = self._pad_forward(resize_image, resize_mask, config)
|
||||
|
||||
# only paste masked area result
|
||||
inpaint_result = cv2.resize(
|
||||
inpaint_result,
|
||||
(origin_size[1], origin_size[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
||||
original_pixel_indices = crop_mask < 127
|
||||
inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
|
||||
original_pixel_indices
|
||||
]
|
||||
|
||||
crop_result.append((inpaint_result, crop_box))
|
||||
|
||||
inpaint_result = image[:, :, ::-1].copy()
|
||||
for crop_image, crop_box in crop_result:
|
||||
x1, y1, x2, y2 = crop_box
|
||||
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input images and output images have same size
|
||||
images: [H, W, C] RGB
|
||||
masks: [H, W] mask area == 255
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
|
||||
image = norm_img(image) # [0, 1]
|
||||
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
||||
mask = (mask > 120) * 255
|
||||
mask = norm_img(mask)
|
||||
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
|
||||
erased_img = image * (1 - mask)
|
||||
input_image = torch.cat([0.5 - mask, erased_img], dim=1)
|
||||
|
||||
output = self.model(input_image)
|
||||
output = (
|
||||
(output.permute(0, 2, 3, 1) * 127.5 + 127.5)
|
||||
.round()
|
||||
.clamp(0, 255)
|
||||
.to(torch.uint8)
|
||||
)
|
||||
output = output[0].cpu().numpy()
|
||||
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return cur_res
|
@ -1,29 +0,0 @@
|
||||
import cv2
|
||||
from .base import InpaintModel
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
|
||||
|
||||
|
||||
class OpenCV2(InpaintModel):
|
||||
name = "cv2"
|
||||
pad_mod = 1
|
||||
is_erase_model = True
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
return True
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
cur_res = cv2.inpaint(
|
||||
image[:, :, ::-1],
|
||||
mask,
|
||||
inpaintRadius=config.cv2_radius,
|
||||
flags=flag_map[config.cv2_flag],
|
||||
)
|
||||
return cur_res
|
@ -1,19 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent.absolute()
|
||||
|
||||
|
||||
def get_config_files() -> Dict[str, Path]:
|
||||
"""
|
||||
- `v1`: Config file for Stable Diffusion v1
|
||||
- `v2`: Config file for Stable Diffusion v2
|
||||
- `xl`: Config file for Stable Diffusion XL
|
||||
- `xl_refiner`: Config file for Stable Diffusion XL Refiner
|
||||
"""
|
||||
return {
|
||||
"v1": CURRENT_DIR / "v1-inference.yaml",
|
||||
"v2": CURRENT_DIR / "v2-inference-v.yaml",
|
||||
"xl": CURRENT_DIR / "sd_xl_base.yaml",
|
||||
"xl_refiner": CURRENT_DIR / "sd_xl_refiner.yaml",
|
||||
}
|
@ -1,93 +0,0 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.13025
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
adm_in_channels: 2816
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: [1, 2, 10]
|
||||
context_dim: 2048
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
layer: hidden
|
||||
layer_idx: 11
|
||||
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
params:
|
||||
arch: ViT-bigG-14
|
||||
version: laion2b_s39b_b160k
|
||||
freeze: True
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
legacy: False
|
||||
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: target_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
@ -1,86 +0,0 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.13025
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
adm_in_channels: 2560
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 384
|
||||
attention_resolutions: [4, 2]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 4
|
||||
context_dim: [1280, 1280, 1280, 1280]
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
params:
|
||||
arch: ViT-bigG-14
|
||||
version: laion2b_s39b_b160k
|
||||
legacy: False
|
||||
freeze: True
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: aesthetic_score
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
@ -1,70 +0,0 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
@ -1,68 +0,0 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-4
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False # we set this to false because this is an inference only config
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
@ -1,68 +0,0 @@
|
||||
import PIL
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.helper import decode_base64_to_image
|
||||
from .base import DiffusionInpaintModel
|
||||
from iopaint.schema import InpaintRequest
|
||||
from .utils import get_torch_dtype, enable_low_mem, is_local_files_only
|
||||
|
||||
|
||||
class PaintByExample(DiffusionInpaintModel):
|
||||
name = "Fantasy-Studio/Paint-by-Example"
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
model_kwargs = {
|
||||
"local_files_only": is_local_files_only(**kwargs),
|
||||
}
|
||||
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
logger.info("Disable Paint By Example Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(safety_checker=None, requires_safety_checker=False)
|
||||
)
|
||||
|
||||
self.model = DiffusionPipeline.from_pretrained(
|
||||
self.name, torch_dtype=torch_dtype, **model_kwargs
|
||||
)
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
# TODO: gpu_id
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
self.model.image_encoder = self.model.image_encoder.to(device)
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
if config.paint_by_example_example_image is None:
|
||||
raise ValueError("paint_by_example_example_image is required")
|
||||
example_image, _, _ = decode_base64_to_image(
|
||||
config.paint_by_example_example_image
|
||||
)
|
||||
output = self.model(
|
||||
image=PIL.Image.fromarray(image),
|
||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||
example_image=PIL.Image.fromarray(example_image),
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
|
||||
output_type="np.array",
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
@ -1,225 +0,0 @@
|
||||
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
|
||||
import torch
|
||||
import numpy as np
|
||||
from .utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta, verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
steps,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=False,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, ):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
File diff suppressed because it is too large
Load Diff
@ -1,101 +0,0 @@
|
||||
from PIL import Image
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from ..base import DiffusionInpaintModel
|
||||
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from ..utils import (
|
||||
handle_from_pretrained_exceptions,
|
||||
get_torch_dtype,
|
||||
enable_low_mem,
|
||||
is_local_files_only,
|
||||
)
|
||||
from iopaint.schema import InpaintRequest
|
||||
from .powerpaint_tokenizer import add_task_to_prompt
|
||||
from ...const import POWERPAINT_NAME
|
||||
|
||||
|
||||
class PowerPaint(DiffusionInpaintModel):
|
||||
name = POWERPAINT_NAME
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from .pipeline_powerpaint import StableDiffusionInpaintPipeline
|
||||
from .powerpaint_tokenizer import PowerPaintTokenizer
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
StableDiffusionInpaintPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.name,
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
|
||||
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs["sd_cpu_textencoder"]:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder, torch_dtype
|
||||
)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
|
||||
config.prompt, config.negative_prompt, config.powerpaint_task
|
||||
)
|
||||
|
||||
output = self.model(
|
||||
image=PIL.Image.fromarray(image),
|
||||
promptA=promptA,
|
||||
promptB=promptB,
|
||||
tradoff=config.fitting_degree,
|
||||
tradoff_nag=config.fitting_degree,
|
||||
negative_promptA=negative_promptA,
|
||||
negative_promptB=negative_promptB,
|
||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||
num_inference_steps=config.sd_steps,
|
||||
strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
callback_steps=1,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
@ -1,186 +0,0 @@
|
||||
from itertools import chain
|
||||
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from iopaint.model.original_sd_configs import get_config_files
|
||||
from loguru import logger
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import numpy as np
|
||||
|
||||
from ..base import DiffusionInpaintModel
|
||||
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from ..utils import (
|
||||
get_torch_dtype,
|
||||
enable_low_mem,
|
||||
is_local_files_only,
|
||||
handle_from_pretrained_exceptions,
|
||||
)
|
||||
from .powerpaint_tokenizer import task_to_prompt
|
||||
from iopaint.schema import InpaintRequest, ModelType
|
||||
from .v2.BrushNet_CA import BrushNetModel
|
||||
from .v2.unet_2d_condition import UNet2DConditionModel_forward
|
||||
from .v2.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D_forward,
|
||||
DownBlock2D_forward,
|
||||
CrossAttnUpBlock2D_forward,
|
||||
UpBlock2D_forward,
|
||||
)
|
||||
|
||||
|
||||
class PowerPaintV2(DiffusionInpaintModel):
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
hf_model_id = "Sanster/PowerPaint_v2"
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from .v2.pipeline_PowerPaint_Brushnet_CA import (
|
||||
StableDiffusionPowerPaintBrushNetPipeline,
|
||||
)
|
||||
from .powerpaint_tokenizer import PowerPaintTokenizer
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
text_encoder_brushnet = CLIPTextModel.from_pretrained(
|
||||
self.hf_model_id,
|
||||
subfolder="text_encoder_brushnet",
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=model_kwargs["local_files_only"],
|
||||
)
|
||||
|
||||
brushnet = BrushNetModel.from_pretrained(
|
||||
self.hf_model_id,
|
||||
subfolder="PowerPaint_Brushnet",
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=model_kwargs["local_files_only"],
|
||||
)
|
||||
|
||||
if self.model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
model_kwargs["num_in_channels"] = 4
|
||||
else:
|
||||
model_kwargs["num_in_channels"] = 9
|
||||
|
||||
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
|
||||
self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
load_safety_checker=False,
|
||||
original_config_file=get_config_files()["v1"],
|
||||
brushnet=brushnet,
|
||||
text_encoder_brushnet=text_encoder_brushnet,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
pipe = handle_from_pretrained_exceptions(
|
||||
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
brushnet=brushnet,
|
||||
text_encoder_brushnet=text_encoder_brushnet,
|
||||
variant="fp16",
|
||||
**model_kwargs,
|
||||
)
|
||||
pipe.tokenizer = PowerPaintTokenizer(
|
||||
CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
|
||||
)
|
||||
self.model = pipe
|
||||
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs["sd_cpu_textencoder"]:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder, torch_dtype
|
||||
)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
|
||||
self.model.unet, self.model.unet.__class__
|
||||
)
|
||||
|
||||
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||
for down_block in chain(
|
||||
self.model.unet.down_blocks, self.model.brushnet.down_blocks
|
||||
):
|
||||
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
|
||||
down_block, down_block.__class__
|
||||
)
|
||||
else:
|
||||
down_block.forward = DownBlock2D_forward.__get__(
|
||||
down_block, down_block.__class__
|
||||
)
|
||||
|
||||
for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
|
||||
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
|
||||
up_block, up_block.__class__
|
||||
)
|
||||
else:
|
||||
up_block.forward = UpBlock2D_forward.__get__(
|
||||
up_block, up_block.__class__
|
||||
)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
image = image * (1 - mask / 255.0)
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
image = PIL.Image.fromarray(image.astype(np.uint8))
|
||||
mask = PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB")
|
||||
|
||||
promptA, promptB, negative_promptA, negative_promptB = task_to_prompt(
|
||||
config.powerpaint_task
|
||||
)
|
||||
|
||||
output = self.model(
|
||||
image=image,
|
||||
mask=mask,
|
||||
promptA=promptA,
|
||||
promptB=promptB,
|
||||
promptU=config.prompt,
|
||||
tradoff=config.fitting_degree,
|
||||
tradoff_nag=config.fitting_degree,
|
||||
negative_promptA=negative_promptA,
|
||||
negative_promptB=negative_promptB,
|
||||
negative_promptU=config.negative_prompt,
|
||||
num_inference_steps=config.sd_steps,
|
||||
# strength=config.sd_strength,
|
||||
brushnet_conditioning_scale=1.0,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
@ -1,254 +0,0 @@
|
||||
import copy
|
||||
import random
|
||||
from typing import Any, List, Union
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from iopaint.schema import PowerPaintTask
|
||||
|
||||
|
||||
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
|
||||
if task == PowerPaintTask.object_remove:
|
||||
promptA = prompt + " P_ctxt"
|
||||
promptB = prompt + " P_ctxt"
|
||||
negative_promptA = negative_prompt + " P_obj"
|
||||
negative_promptB = negative_prompt + " P_obj"
|
||||
elif task == PowerPaintTask.context_aware:
|
||||
promptA = prompt + " P_ctxt"
|
||||
promptB = prompt + " P_ctxt"
|
||||
negative_promptA = negative_prompt
|
||||
negative_promptB = negative_prompt
|
||||
elif task == PowerPaintTask.shape_guided:
|
||||
promptA = prompt + " P_shape"
|
||||
promptB = prompt + " P_ctxt"
|
||||
negative_promptA = negative_prompt
|
||||
negative_promptB = negative_prompt
|
||||
elif task == PowerPaintTask.outpainting:
|
||||
promptA = prompt + " P_ctxt"
|
||||
promptB = prompt + " P_ctxt"
|
||||
negative_promptA = negative_prompt + " P_obj"
|
||||
negative_promptB = negative_prompt + " P_obj"
|
||||
else:
|
||||
promptA = prompt + " P_obj"
|
||||
promptB = prompt + " P_obj"
|
||||
negative_promptA = negative_prompt
|
||||
negative_promptB = negative_prompt
|
||||
|
||||
return promptA, promptB, negative_promptA, negative_promptB
|
||||
|
||||
|
||||
def task_to_prompt(task: PowerPaintTask):
|
||||
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
|
||||
"", "", task
|
||||
)
|
||||
return (
|
||||
promptA.strip(),
|
||||
promptB.strip(),
|
||||
negative_promptA.strip(),
|
||||
negative_promptB.strip(),
|
||||
)
|
||||
|
||||
|
||||
class PowerPaintTokenizer:
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.wrapped = tokenizer
|
||||
self.token_map = {}
|
||||
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
|
||||
num_vec_per_token = 10
|
||||
for placeholder_token in placeholder_tokens:
|
||||
output = []
|
||||
for i in range(num_vec_per_token):
|
||||
ith_token = placeholder_token + f"_{i}"
|
||||
output.append(ith_token)
|
||||
self.token_map[placeholder_token] = output
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name == "wrapped":
|
||||
return super().__getattr__("wrapped")
|
||||
|
||||
try:
|
||||
return getattr(self.wrapped, name)
|
||||
except AttributeError:
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"'name' cannot be found in both "
|
||||
f"'{self.__class__.__name__}' and "
|
||||
f"'{self.__class__.__name__}.tokenizer'."
|
||||
)
|
||||
|
||||
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
|
||||
"""Attempt to add tokens to the tokenizer.
|
||||
|
||||
Args:
|
||||
tokens (Union[str, List[str]]): The tokens to be added.
|
||||
"""
|
||||
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
|
||||
assert num_added_tokens != 0, (
|
||||
f"The tokenizer already contains the token {tokens}. Please pass "
|
||||
"a different `placeholder_token` that is not already in the "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
def get_token_info(self, token: str) -> dict:
|
||||
"""Get the information of a token, including its start and end index in
|
||||
the current tokenizer.
|
||||
|
||||
Args:
|
||||
token (str): The token to be queried.
|
||||
|
||||
Returns:
|
||||
dict: The information of the token, including its start and end
|
||||
index in current tokenizer.
|
||||
"""
|
||||
token_ids = self.__call__(token).input_ids
|
||||
start, end = token_ids[1], token_ids[-2] + 1
|
||||
return {"name": token, "start": start, "end": end}
|
||||
|
||||
def add_placeholder_token(
|
||||
self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
|
||||
):
|
||||
"""Add placeholder tokens to the tokenizer.
|
||||
|
||||
Args:
|
||||
placeholder_token (str): The placeholder token to be added.
|
||||
num_vec_per_token (int, optional): The number of vectors of
|
||||
the added placeholder token.
|
||||
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
|
||||
"""
|
||||
output = []
|
||||
if num_vec_per_token == 1:
|
||||
self.try_adding_tokens(placeholder_token, *args, **kwargs)
|
||||
output.append(placeholder_token)
|
||||
else:
|
||||
output = []
|
||||
for i in range(num_vec_per_token):
|
||||
ith_token = placeholder_token + f"_{i}"
|
||||
self.try_adding_tokens(ith_token, *args, **kwargs)
|
||||
output.append(ith_token)
|
||||
|
||||
for token in self.token_map:
|
||||
if token in placeholder_token:
|
||||
raise ValueError(
|
||||
f"The tokenizer already has placeholder token {token} "
|
||||
f"that can get confused with {placeholder_token} "
|
||||
"keep placeholder tokens independent"
|
||||
)
|
||||
self.token_map[placeholder_token] = output
|
||||
|
||||
def replace_placeholder_tokens_in_text(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
vector_shuffle: bool = False,
|
||||
prop_tokens_to_load: float = 1.0,
|
||||
) -> Union[str, List[str]]:
|
||||
"""Replace the keywords in text with placeholder tokens. This function
|
||||
will be called in `self.__call__` and `self.encode`.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be processed.
|
||||
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
||||
Defaults to False.
|
||||
prop_tokens_to_load (float, optional): The proportion of tokens to
|
||||
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The processed text.
|
||||
"""
|
||||
if isinstance(text, list):
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
output.append(
|
||||
self.replace_placeholder_tokens_in_text(
|
||||
text[i], vector_shuffle=vector_shuffle
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
||||
for placeholder_token in self.token_map:
|
||||
if placeholder_token in text:
|
||||
tokens = self.token_map[placeholder_token]
|
||||
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
|
||||
if vector_shuffle:
|
||||
tokens = copy.copy(tokens)
|
||||
random.shuffle(tokens)
|
||||
text = text.replace(placeholder_token, " ".join(tokens))
|
||||
return text
|
||||
|
||||
def replace_text_with_placeholder_tokens(
|
||||
self, text: Union[str, List[str]]
|
||||
) -> Union[str, List[str]]:
|
||||
"""Replace the placeholder tokens in text with the original keywords.
|
||||
This function will be called in `self.decode`.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be processed.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The processed text.
|
||||
"""
|
||||
if isinstance(text, list):
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
output.append(self.replace_text_with_placeholder_tokens(text[i]))
|
||||
return output
|
||||
|
||||
for placeholder_token, tokens in self.token_map.items():
|
||||
merged_tokens = " ".join(tokens)
|
||||
if merged_tokens in text:
|
||||
text = text.replace(merged_tokens, placeholder_token)
|
||||
return text
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
*args,
|
||||
vector_shuffle: bool = False,
|
||||
prop_tokens_to_load: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""The call function of the wrapper.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be tokenized.
|
||||
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
||||
Defaults to False.
|
||||
prop_tokens_to_load (float, optional): The proportion of tokens to
|
||||
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
|
||||
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
||||
"""
|
||||
replaced_text = self.replace_placeholder_tokens_in_text(
|
||||
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
|
||||
)
|
||||
|
||||
return self.wrapped.__call__(replaced_text, *args, **kwargs)
|
||||
|
||||
def encode(self, text: Union[str, List[str]], *args, **kwargs):
|
||||
"""Encode the passed text to token index.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be encode.
|
||||
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
||||
"""
|
||||
replaced_text = self.replace_placeholder_tokens_in_text(text)
|
||||
return self.wrapped(replaced_text, *args, **kwargs)
|
||||
|
||||
def decode(
|
||||
self, token_ids, return_raw: bool = False, *args, **kwargs
|
||||
) -> Union[str, List[str]]:
|
||||
"""Decode the token index to text.
|
||||
|
||||
Args:
|
||||
token_ids: The token index to be decoded.
|
||||
return_raw: Whether keep the placeholder token in the text.
|
||||
Defaults to False.
|
||||
*args, **kwargs: The arguments for `self.wrapped.decode`.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The decoded text.
|
||||
"""
|
||||
text = self.wrapped.decode(token_ids, *args, **kwargs)
|
||||
if return_raw:
|
||||
return text
|
||||
replaced_text = self.replace_text_with_placeholder_tokens(text)
|
||||
return replaced_text
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,342 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.utils import is_torch_version, logging
|
||||
from diffusers.utils.torch_utils import apply_freeu
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def CrossAttnDownBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
lora_scale = (
|
||||
cross_attention_kwargs.get("scale", 1.0)
|
||||
if cross_attention_kwargs is not None
|
||||
else 1.0
|
||||
)
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = (
|
||||
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||
hidden_states = hidden_states + additional_residuals
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(
|
||||
0
|
||||
) # todo: add before or after
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def DownBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(
|
||||
0
|
||||
) # todo: add before or after
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def CrossAttnUpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
return_res_samples: Optional[bool] = False,
|
||||
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
lora_scale = (
|
||||
cross_attention_kwargs.get("scale", 1.0)
|
||||
if cross_attention_kwargs is not None
|
||||
else 1.0
|
||||
)
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
if return_res_samples:
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = (
|
||||
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||
|
||||
if return_res_samples:
|
||||
return hidden_states, output_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def UpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
return_res_samples: Optional[bool] = False,
|
||||
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
if return_res_samples:
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(
|
||||
0
|
||||
) # todo: add before or after
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(
|
||||
0
|
||||
) # todo: add before or after
|
||||
|
||||
if return_res_samples:
|
||||
return hidden_states, output_states
|
||||
else:
|
||||
return hidden_states
|
@ -1,402 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def UNet2DConditionModel_forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
||||
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DConditionModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
encoder_attention_mask (`torch.Tensor`):
|
||||
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||
example from ControlNet side model(s)
|
||||
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||
|
||||
Returns:
|
||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
for dim in sample.shape[-2:]:
|
||||
if dim % default_overall_up_factor != 0:
|
||||
# Forward upsample size to force interpolation output size.
|
||||
forward_upsample_size = True
|
||||
break
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = (
|
||||
1 - encoder_attention_mask.to(sample.dtype)
|
||||
) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
||||
if class_emb is not None:
|
||||
if self.config.class_embeddings_concat:
|
||||
emb = torch.cat([emb, class_emb], dim=-1)
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
aug_emb = self.get_aug_embed(
|
||||
emb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
)
|
||||
if self.config.addition_embed_type == "image_hint":
|
||||
aug_emb, hint = aug_emb
|
||||
sample = torch.cat([sample, hint], dim=1)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
|
||||
encoder_hidden_states = self.process_encoder_hidden_states(
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 2.5 GLIGEN position net
|
||||
if (
|
||||
cross_attention_kwargs is not None
|
||||
and cross_attention_kwargs.get("gligen", None) is not None
|
||||
):
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||
gligen_args = cross_attention_kwargs.pop("gligen")
|
||||
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||
|
||||
# 3. down
|
||||
lora_scale = (
|
||||
cross_attention_kwargs.get("scale", 1.0)
|
||||
if cross_attention_kwargs is not None
|
||||
else 1.0
|
||||
)
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
is_controlnet = (
|
||||
mid_block_additional_residual is not None
|
||||
and down_block_additional_residuals is not None
|
||||
)
|
||||
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||
is_adapter = down_intrablock_additional_residuals is not None
|
||||
# maintain backward compatibility for legacy usage, where
|
||||
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||
# but can only use one or the other
|
||||
is_brushnet = (
|
||||
down_block_add_samples is not None
|
||||
and mid_block_add_sample is not None
|
||||
and up_block_add_samples is not None
|
||||
)
|
||||
if (
|
||||
not is_adapter
|
||||
and mid_block_additional_residual is None
|
||||
and down_block_additional_residuals is not None
|
||||
):
|
||||
deprecate(
|
||||
"T2I should not use down_block_additional_residuals",
|
||||
"1.3.0",
|
||||
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
||||
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
||||
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
||||
standard_warn=False,
|
||||
)
|
||||
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||
is_adapter = True
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
|
||||
if is_brushnet:
|
||||
sample = sample + down_block_add_samples.pop(0)
|
||||
|
||||
for downsample_block in self.down_blocks:
|
||||
if (
|
||||
hasattr(downsample_block, "has_cross_attention")
|
||||
and downsample_block.has_cross_attention
|
||||
):
|
||||
# For t2i-adapter CrossAttnDownBlock2D
|
||||
additional_residuals = {}
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
additional_residuals["additional_residuals"] = (
|
||||
down_intrablock_additional_residuals.pop(0)
|
||||
)
|
||||
|
||||
if is_brushnet and len(down_block_add_samples) > 0:
|
||||
additional_residuals["down_block_add_samples"] = [
|
||||
down_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(downsample_block.resnets)
|
||||
+ (downsample_block.downsamplers != None)
|
||||
)
|
||||
]
|
||||
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(down_block_add_samples) > 0:
|
||||
additional_residuals["down_block_add_samples"] = [
|
||||
down_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(downsample_block.resnets)
|
||||
+ (downsample_block.downsamplers != None)
|
||||
)
|
||||
]
|
||||
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
scale=lora_scale,
|
||||
**additional_residuals,
|
||||
)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
if is_controlnet:
|
||||
new_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, down_block_additional_residual in zip(
|
||||
down_block_res_samples, down_block_additional_residuals
|
||||
):
|
||||
down_block_res_sample = (
|
||||
down_block_res_sample + down_block_additional_residual
|
||||
)
|
||||
new_down_block_res_samples = new_down_block_res_samples + (
|
||||
down_block_res_sample,
|
||||
)
|
||||
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if (
|
||||
hasattr(self.mid_block, "has_cross_attention")
|
||||
and self.mid_block.has_cross_attention
|
||||
):
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# To support T2I-Adapter-XL
|
||||
if (
|
||||
is_adapter
|
||||
and len(down_intrablock_additional_residuals) > 0
|
||||
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||
):
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
if is_controlnet:
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
if is_brushnet:
|
||||
sample = sample + mid_block_add_sample
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if (
|
||||
hasattr(upsample_block, "has_cross_attention")
|
||||
and upsample_block.has_cross_attention
|
||||
):
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(up_block_add_samples) > 0:
|
||||
additional_residuals["up_block_add_samples"] = [
|
||||
up_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(upsample_block.resnets)
|
||||
+ (upsample_block.upsamplers != None)
|
||||
)
|
||||
]
|
||||
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(up_block_add_samples) > 0:
|
||||
additional_residuals["up_block_add_samples"] = [
|
||||
up_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(upsample_block.resnets)
|
||||
+ (upsample_block.upsamplers != None)
|
||||
)
|
||||
]
|
||||
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
scale=lora_scale,
|
||||
**additional_residuals,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
if self.conv_norm_out:
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DConditionOutput(sample=sample)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user