From 6253016019b4eec651fec7a14a9f2c6b7018ef55 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Jan 2024 17:13:11 +0800 Subject: [PATCH] add diffusion progress --- lama_cleaner/api.py | 30 ++++- lama_cleaner/model/controlnet.py | 4 +- lama_cleaner/model/kandinsky.py | 3 +- lama_cleaner/model/sd.py | 3 +- lama_cleaner/model/sdxl.py | 4 +- lama_cleaner/model/utils.py | 1 - lama_cleaner/tests/test_model_switch.py | 3 - lama_cleaner/tests/test_outpainting.py | 12 -- lama_cleaner/tests/test_sdxl.py | 8 -- requirements.txt | 3 +- web_app/package-lock.json | 109 ++++++++++++++++++- web_app/package.json | 2 + web_app/src/components/DiffusionProgress.tsx | 67 ++++++++++++ web_app/src/components/Workspace.tsx | 2 + web_app/src/components/ui/progress.tsx | 26 +++++ web_app/src/lib/api.ts | 2 +- web_app/src/lib/states.ts | 2 +- 17 files changed, 239 insertions(+), 42 deletions(-) create mode 100644 web_app/src/components/DiffusionProgress.tsx create mode 100644 web_app/src/components/ui/progress.tsx diff --git a/lama_cleaner/api.py b/lama_cleaner/api.py index a8fee6d..b5e180c 100644 --- a/lama_cleaner/api.py +++ b/lama_cleaner/api.py @@ -6,6 +6,9 @@ from pathlib import Path from typing import Optional, Dict, List import cv2 +import socketio +import asyncio +from socketio import AsyncServer import torch import numpy as np from loguru import logger @@ -109,6 +112,19 @@ def api_middleware(app: FastAPI): app.add_middleware(CORSMiddleware, **cors_options) +global_sio: AsyncServer = None + + +def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict): + # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict + # logger.info(f"diffusion callback: step={step}, timestep={timestep}") + + # We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI, + # but for now let's just start a separate event loop. 
It shouldn't make a difference for single person use + asyncio.run(global_sio.emit("diffusion_progress", {"step": step})) + return {} + + class Api: def __init__(self, app: FastAPI, config: ApiConfig): self.app = app @@ -134,6 +150,12 @@ class Api: self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") # fmt: on + global global_sio + self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app) + self.app.mount("/ws", self.combined_asgi_app) + global_sio = self.sio + def add_api_route(self, path: str, endpoint, **kwargs): return self.app.add_api_route(path, endpoint, **kwargs) @@ -206,6 +228,9 @@ class Api: quality=self.config.quality, infos=infos, ) + + asyncio.run(self.sio.emit("diffusion_finish")) + return Response( content=res_img_bytes, media_type=f"image/{ext}", @@ -246,10 +271,10 @@ class Api: def launch(self): self.app.include_router(self.router) uvicorn.run( - self.app, + self.combined_asgi_app, host=self.config.host, port=self.config.port, - timeout_keep_alive=60, + timeout_keep_alive=999999999, ) def _build_file_manager(self) -> Optional[FileManager]: @@ -290,6 +315,7 @@ class Api: disable_nsfw=self.config.disable_nsfw_checker, sd_cpu_textencoder=self.config.cpu_textencoder, cpu_offload=self.config.cpu_offload, + callback=diffuser_callback, ) diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index cbe1e5f..727dc61 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -38,7 +38,7 @@ class ControlNet(DiffusionInpaintModel): def init_model(self, device: torch.device, **kwargs): fp16 = not kwargs.get("no_half", False) - model_info = kwargs["model_info"] + model_info = kwargs["model_info"] controlnet_method = kwargs["controlnet_method"] self.model_info = model_info @@ -154,7 +154,7 @@ class ControlNet(DiffusionInpaintModel): num_inference_steps=config.sd_steps, guidance_scale=config.sd_guidance_scale, output_type="np", - callback=self.callback, + callback_on_step_end=self.callback, height=img_h, width=img_w, generator=torch.manual_seed(config.sd_seed), diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index e5467e0..5bcfffd 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -52,9 +52,8 @@ class Kandinsky(DiffusionInpaintModel): num_inference_steps=config.sd_steps, guidance_scale=config.sd_guidance_scale, output_type="np", - callback=self.callback, + callback_on_step_end=self.callback, generator=generator, - callback_steps=1, ).images[0] output = (output * 255).round().astype("uint8") diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 6f77abf..c3cad9e 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -83,11 +83,10 @@ class SD(DiffusionInpaintModel): strength=config.sd_strength, guidance_scale=config.sd_guidance_scale, output_type="np", - callback=self.callback, + callback_on_step_end=self.callback, height=img_h, width=img_w, generator=torch.manual_seed(config.sd_seed), - callback_steps=1, ).images[0] output = (output * 255).round().astype("uint8") diff --git a/lama_cleaner/model/sdxl.py b/lama_cleaner/model/sdxl.py index f1bb60c..3d4a9b6 100644 --- a/lama_cleaner/model/sdxl.py +++ b/lama_cleaner/model/sdxl.py @@ -2,7 +2,6 @@ import os import PIL.Image import cv2 -import numpy as np import torch from diffusers import AutoencoderKL from loguru import logger @@ -79,11 +78,10 @@ class SDXL(DiffusionInpaintModel): 
strength=0.999 if config.sd_strength == 1.0 else config.sd_strength, guidance_scale=config.sd_guidance_scale, output_type="np", - callback=self.callback, + callback_on_step_end=self.callback, height=img_h, width=img_w, generator=torch.manual_seed(config.sd_seed), - callback_steps=1, ).images[0] output = (output * 255).round().astype("uint8") diff --git a/lama_cleaner/model/utils.py b/lama_cleaner/model/utils.py index 535bacf..509ed9b 100644 --- a/lama_cleaner/model/utils.py +++ b/lama_cleaner/model/utils.py @@ -977,7 +977,6 @@ def handle_from_pretrained_exceptions(func, **kwargs): try: return func(**kwargs) except ValueError as e: - # 处理异常的逻辑 if "You are trying to load the model files of the `variant=fp16`" in str(e): logger.info("variant=fp16 not found, try revision=fp16") return func(**{**kwargs, "variant": None, "revision": "fp16"}) diff --git a/lama_cleaner/tests/test_model_switch.py b/lama_cleaner/tests/test_model_switch.py index 588bf82..a360566 100644 --- a/lama_cleaner/tests/test_model_switch.py +++ b/lama_cleaner/tests/test_model_switch.py @@ -18,7 +18,6 @@ def test_model_switch(): disable_nsfw=True, sd_cpu_textencoder=True, cpu_offload=False, - callback=None, ) model.switch("lama") @@ -34,7 +33,6 @@ def test_controlnet_switch_onoff(caplog): disable_nsfw=True, sd_cpu_textencoder=True, cpu_offload=False, - callback=None, ) model.switch_controlnet_method( @@ -59,7 +57,6 @@ def test_switch_controlnet_method(caplog): disable_nsfw=True, sd_cpu_textencoder=True, cpu_offload=False, - callback=None, ) model.switch_controlnet_method( diff --git a/lama_cleaner/tests/test_outpainting.py b/lama_cleaner/tests/test_outpainting.py index b47e1ee..f261b3c 100644 --- a/lama_cleaner/tests/test_outpainting.py +++ b/lama_cleaner/tests/test_outpainting.py @@ -30,15 +30,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal def test_outpainting(name, device, rect): sd_steps = check_device(device) - def callback(i, t, latents): - pass - model = ModelManager( name=name, device=torch.device(device), disable_nsfw=True, sd_cpu_textencoder=False, - callback=callback, ) cfg = get_config( prompt="a dog sitting on a bench in the park", @@ -72,15 +68,11 @@ def test_outpainting(name, device, rect): def test_kandinsky_outpainting(name, device, rect): sd_steps = check_device(device) - def callback(i, t, latents): - pass - model = ModelManager( name=name, device=torch.device(device), disable_nsfw=True, sd_cpu_textencoder=False, - callback=callback, ) cfg = get_config( prompt="a cat", @@ -117,15 +109,11 @@ def test_kandinsky_outpainting(name, device, rect): def test_powerpaint_outpainting(name, device, rect): sd_steps = check_device(device) - def callback(i, t, latents): - pass - model = ModelManager( name=name, device=torch.device(device), disable_nsfw=True, sd_cpu_textencoder=False, - callback=callback, ) cfg = get_config( prompt="a dog sitting on a bench in the park", diff --git a/lama_cleaner/tests/test_sdxl.py b/lama_cleaner/tests/test_sdxl.py index 7d0b5e3..506c963 100644 --- a/lama_cleaner/tests/test_sdxl.py +++ b/lama_cleaner/tests/test_sdxl.py @@ -18,15 +18,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal def test_sdxl(device, strategy, sampler): sd_steps = check_device(device) - def callback(i, t, latents): - pass - model = ModelManager( name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", device=torch.device(device), disable_nsfw=True, sd_cpu_textencoder=False, - callback=callback, ) cfg = get_config( strategy=strategy, @@ -54,15 +50,11 @@ def test_sdxl(device, 
strategy, sampler): def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler): sd_steps = check_device(device) - def callback(i, t, latents): - pass - model = ModelManager( name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", device=torch.device(device), disable_nsfw=True, sd_cpu_textencoder=False, - callback=callback, ) cfg = get_config( strategy=strategy, diff --git a/requirements.txt b/requirements.txt index 0c4d536..e62ff5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,7 @@ transformers==4.34.1 safetensors controlnet-aux==0.0.3 fastapi==0.108.0 -python-multipart -simple-websocket +python-socketio==5.7.2 flaskwebgui==0.3.5 typer pydantic diff --git a/web_app/package-lock.json b/web_app/package-lock.json index 11f63cd..408241e 100644 --- a/web_app/package-lock.json +++ b/web_app/package-lock.json @@ -17,6 +17,7 @@ "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-label": "^2.0.2", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-progress": "^1.0.3", "@radix-ui/react-radio-group": "^1.1.3", "@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-select": "^2.0.0", @@ -48,6 +49,7 @@ "react-use": "^17.4.0", "react-zoom-pan-pinch": "^3.3.0", "recoil": "^0.7.7", + "socket.io-client": "^4.7.2", "tailwind-merge": "^2.0.0", "tailwindcss-animate": "^1.0.7", "zod": "^3.22.4", @@ -1914,6 +1916,30 @@ } } }, + "node_modules/@radix-ui/react-progress": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.0.3.tgz", + "integrity": "sha512-5G6Om/tYSxjSeEdrb1VfKkfZfn/1IlPWd731h2RfPuSbIfNUgfqAwbKfJCg/PP6nuUCTrYzalwHSpSinoWoCag==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-radio-group": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.1.3.tgz", @@ -2587,6 +2613,11 @@ "win32" ] }, + "node_modules/@socket.io/component-emitter": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", + "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==" + }, "node_modules/@swc/core": { "version": "1.3.96", "resolved": "https://registry.npmjs.org/@swc/core/-/core-1.3.96.tgz", @@ -3804,7 +3835,6 @@ "version": "4.3.4", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", - "dev": true, "dependencies": { "ms": "2.1.2" }, @@ -3876,6 +3906,26 @@ "integrity": "sha512-soytjxwbgcCu7nh5Pf4S2/4wa6UIu+A3p03U2yVr53qGxi1/VTR3ENI+p50v+UxqqZAfl48j3z55ud7VHIOr9w==", "dev": true }, + "node_modules/engine.io-client": { + "version": "6.5.3", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz", + "integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1", + "engine.io-parser": "~5.2.1", + "ws": "~8.11.0", + "xmlhttprequest-ssl": "~2.0.0" + } + }, + 
"node_modules/engine.io-parser": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.1.tgz", + "integrity": "sha512-9JktcM3u18nU9N2Lz3bWeBgxVgOKpw7yhRaoxQA3FUDZzzw+9WlA6p4G4u0RixNkg14fH7EfEc/RhpurtiROTQ==", + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/error-stack-parser": { "version": "2.1.4", "resolved": "https://registry.npmjs.org/error-stack-parser/-/error-stack-parser-2.1.4.tgz", @@ -4813,8 +4863,7 @@ "node_modules/ms": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", - "dev": true + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" }, "node_modules/mz": { "version": "2.7.0", @@ -5690,6 +5739,32 @@ "node": ">=8" } }, + "node_modules/socket.io-client": { + "version": "4.7.2", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.7.2.tgz", + "integrity": "sha512-vtA0uD4ibrYD793SOIAwlo8cj6haOeMHrGvwPxJsxH7CeIksqJ+3Zc06RvWTIFgiSqx4A3sOnTXpfAEE2Zyz6w==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.2", + "engine.io-client": "~6.5.2", + "socket.io-parser": "~4.2.4" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-parser": { + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/source-map": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", @@ -6256,6 +6331,34 @@ "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==" }, + "node_modules/ws": { + "version": "8.11.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", + "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": "^5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xmlhttprequest-ssl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz", + "integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==", + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", diff --git a/web_app/package.json b/web_app/package.json index 5f201e5..81f7926 100644 --- a/web_app/package.json +++ b/web_app/package.json @@ -19,6 +19,7 @@ "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-label": "^2.0.2", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-progress": "^1.0.3", "@radix-ui/react-radio-group": "^1.1.3", "@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-select": "^2.0.0", @@ -50,6 +51,7 @@ "react-use": "^17.4.0", "react-zoom-pan-pinch": "^3.3.0", "recoil": "^0.7.7", + "socket.io-client": "^4.7.2", 
"tailwind-merge": "^2.0.0", "tailwindcss-animate": "^1.0.7", "zod": "^3.22.4", diff --git a/web_app/src/components/DiffusionProgress.tsx b/web_app/src/components/DiffusionProgress.tsx new file mode 100644 index 0000000..9828c56 --- /dev/null +++ b/web_app/src/components/DiffusionProgress.tsx @@ -0,0 +1,67 @@ +import * as React from "react" +import io from "socket.io-client" +import { Progress } from "./ui/progress" +import { useStore } from "@/lib/states" +import { MODEL_TYPE_INPAINT } from "@/lib/const" + +export const API_ENDPOINT = import.meta.env.VITE_BACKEND + ? import.meta.env.VITE_BACKEND + : "" +const socket = io(API_ENDPOINT) + +const DiffusionProgress = () => { + const [settings, isInpainting] = useStore((state) => [ + state.settings, + state.isInpainting, + ]) + + const [isConnected, setIsConnected] = React.useState(false) + const [step, setStep] = React.useState(0) + + const progress = Math.min(Math.round((step / settings.sdSteps) * 100), 100) + const isSD = settings.model.model_type !== MODEL_TYPE_INPAINT + + React.useEffect(() => { + if (!isSD) { + return + } + socket.on("connect", () => { + setIsConnected(true) + }) + + socket.on("disconnect", () => { + setIsConnected(false) + }) + + socket.on("diffusion_progress", (data) => { + if (data) { + setStep(data.step + 1) + } + }) + + socket.on("diffusion_finish", (data) => { + setStep(0) + }) + + return () => { + socket.off("connect") + socket.off("disconnect") + socket.off("diffusion_progress") + socket.off("diffusion_finish") + } + }, [isSD]) + + return ( +
+    <div
+      className="z-10 absolute flex justify-center items-center gap-4 w-[220px] left-1/2 -translate-x-1/2 bottom-[68px] pointer-events-none"
+      style={{
+        visibility: isSD && isConnected && isInpainting ? "visible" : "hidden",
+      }}
+    >
+      <Progress value={progress} />
+      <div className="w-[45px] flex justify-center">{progress}%</div>
+    </div>
+ ) +} + +export default DiffusionProgress diff --git a/web_app/src/components/Workspace.tsx b/web_app/src/components/Workspace.tsx index bef1fda..cba0a15 100644 --- a/web_app/src/components/Workspace.tsx +++ b/web_app/src/components/Workspace.tsx @@ -6,6 +6,7 @@ import ImageSize from "./ImageSize" import Plugins from "./Plugins" import { InteractiveSeg } from "./InteractiveSeg" import SidePanel from "./SidePanel" +import DiffusionProgress from "./DiffusionProgress" const Workspace = () => { const [file, updateSettings] = useStore((state) => [ @@ -28,6 +29,7 @@ const Workspace = () => { + {file ? : <>} diff --git a/web_app/src/components/ui/progress.tsx b/web_app/src/components/ui/progress.tsx new file mode 100644 index 0000000..3fd47ad --- /dev/null +++ b/web_app/src/components/ui/progress.tsx @@ -0,0 +1,26 @@ +import * as React from "react" +import * as ProgressPrimitive from "@radix-ui/react-progress" + +import { cn } from "@/lib/utils" + +const Progress = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, value, ...props }, ref) => ( + + + +)) +Progress.displayName = ProgressPrimitive.Root.displayName + +export { Progress } diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 129bddd..3766e78 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -11,7 +11,7 @@ import { convertToBase64, srcToFile } from "@/lib/utils" import axios from "axios" export const API_ENDPOINT = import.meta.env.VITE_BACKEND - ? import.meta.env.VITE_BACKEND + ? import.meta.env.VITE_BACKEND + "/api/v1" : "/api/v1" const api = axios.create({ diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index ba7c5a7..6eba773 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -312,7 +312,7 @@ const defaultValues: AppState = { sdGuidanceScale: 7.5, sdSampler: "DPM++ 2M", sdMatchHistograms: false, - sdScale: 100, + sdScale: 1.0, p2pImageGuidanceScale: 1.5, controlnetConditioningScale: 0.4, controlnetMethod: "lllyasviel/control_v11p_sd15_canny",