add diffusion progress

This commit is contained in:
Qing 2024-01-02 17:13:11 +08:00
parent f38be37f8c
commit 6253016019
17 changed files with 239 additions and 42 deletions

View File

@ -6,6 +6,9 @@ from pathlib import Path
from typing import Optional, Dict, List
import cv2
import socketio
import asyncio
from socketio import AsyncServer
import torch
import numpy as np
from loguru import logger
@ -109,6 +112,19 @@ def api_middleware(app: FastAPI):
app.add_middleware(CORSMiddleware, **cors_options)
global_sio: AsyncServer = None
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict):
# self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
# logger.info(f"diffusion callback: step={step}, timestep={timestep}")
# We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
# but for now let's just start a separate event loop. It shouldn't make a difference for single person use
asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
return {}
class Api:
def __init__(self, app: FastAPI, config: ApiConfig):
self.app = app
@ -134,6 +150,12 @@ class Api:
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
global global_sio
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
self.app.mount("/ws", self.combined_asgi_app)
global_sio = self.sio
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
@ -206,6 +228,9 @@ class Api:
quality=self.config.quality,
infos=infos,
)
asyncio.run(self.sio.emit("diffusion_finish"))
return Response(
content=res_img_bytes,
media_type=f"image/{ext}",
@ -246,10 +271,10 @@ class Api:
def launch(self):
self.app.include_router(self.router)
uvicorn.run(
self.app,
self.combined_asgi_app,
host=self.config.host,
port=self.config.port,
timeout_keep_alive=60,
timeout_keep_alive=999999999,
)
def _build_file_manager(self) -> Optional[FileManager]:
@ -290,6 +315,7 @@ class Api:
disable_nsfw=self.config.disable_nsfw_checker,
sd_cpu_textencoder=self.config.cpu_textencoder,
cpu_offload=self.config.cpu_offload,
callback=diffuser_callback,
)

View File

@ -38,7 +38,7 @@ class ControlNet(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get("no_half", False)
model_info = kwargs["model_info"]
model_info = kwargs["model_info"]
controlnet_method = kwargs["controlnet_method"]
self.model_info = model_info
@ -154,7 +154,7 @@ class ControlNet(DiffusionInpaintModel):
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),

View File

@ -52,9 +52,8 @@ class Kandinsky(DiffusionInpaintModel):
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
generator=generator,
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@ -83,11 +83,10 @@ class SD(DiffusionInpaintModel):
strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@ -2,7 +2,6 @@ import os
import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL
from loguru import logger
@ -79,11 +78,10 @@ class SDXL(DiffusionInpaintModel):
strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@ -977,7 +977,6 @@ def handle_from_pretrained_exceptions(func, **kwargs):
try:
return func(**kwargs)
except ValueError as e:
# 处理异常的逻辑
if "You are trying to load the model files of the `variant=fp16`" in str(e):
logger.info("variant=fp16 not found, try revision=fp16")
return func(**{**kwargs, "variant": None, "revision": "fp16"})

View File

@ -18,7 +18,6 @@ def test_model_switch():
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
callback=None,
)
model.switch("lama")
@ -34,7 +33,6 @@ def test_controlnet_switch_onoff(caplog):
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
callback=None,
)
model.switch_controlnet_method(
@ -59,7 +57,6 @@ def test_switch_controlnet_method(caplog):
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
callback=None,
)
model.switch_controlnet_method(

View File

@ -30,15 +30,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
def test_outpainting(name, device, rect):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name=name,
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
prompt="a dog sitting on a bench in the park",
@ -72,15 +68,11 @@ def test_outpainting(name, device, rect):
def test_kandinsky_outpainting(name, device, rect):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name=name,
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
prompt="a cat",
@ -117,15 +109,11 @@ def test_kandinsky_outpainting(name, device, rect):
def test_powerpaint_outpainting(name, device, rect):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name=name,
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
prompt="a dog sitting on a bench in the park",

View File

@ -18,15 +18,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
def test_sdxl(device, strategy, sampler):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
strategy=strategy,
@ -54,15 +50,11 @@ def test_sdxl(device, strategy, sampler):
def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
strategy=strategy,

View File

@ -5,8 +5,7 @@ transformers==4.34.1
safetensors
controlnet-aux==0.0.3
fastapi==0.108.0
python-multipart
simple-websocket
python-socketio==5.7.2
flaskwebgui==0.3.5
typer
pydantic

View File

@ -17,6 +17,7 @@
"@radix-ui/react-icons": "^1.3.0",
"@radix-ui/react-label": "^2.0.2",
"@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-progress": "^1.0.3",
"@radix-ui/react-radio-group": "^1.1.3",
"@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0",
@ -48,6 +49,7 @@
"react-use": "^17.4.0",
"react-zoom-pan-pinch": "^3.3.0",
"recoil": "^0.7.7",
"socket.io-client": "^4.7.2",
"tailwind-merge": "^2.0.0",
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4",
@ -1914,6 +1916,30 @@
}
}
},
"node_modules/@radix-ui/react-progress": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.0.3.tgz",
"integrity": "sha512-5G6Om/tYSxjSeEdrb1VfKkfZfn/1IlPWd731h2RfPuSbIfNUgfqAwbKfJCg/PP6nuUCTrYzalwHSpSinoWoCag==",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/react-context": "1.0.1",
"@radix-ui/react-primitive": "1.0.3"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0",
"react-dom": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-radio-group": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.1.3.tgz",
@ -2587,6 +2613,11 @@
"win32"
]
},
"node_modules/@socket.io/component-emitter": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz",
"integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg=="
},
"node_modules/@swc/core": {
"version": "1.3.96",
"resolved": "https://registry.npmjs.org/@swc/core/-/core-1.3.96.tgz",
@ -3804,7 +3835,6 @@
"version": "4.3.4",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
"dev": true,
"dependencies": {
"ms": "2.1.2"
},
@ -3876,6 +3906,26 @@
"integrity": "sha512-soytjxwbgcCu7nh5Pf4S2/4wa6UIu+A3p03U2yVr53qGxi1/VTR3ENI+p50v+UxqqZAfl48j3z55ud7VHIOr9w==",
"dev": true
},
"node_modules/engine.io-client": {
"version": "6.5.3",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz",
"integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.2.1",
"ws": "~8.11.0",
"xmlhttprequest-ssl": "~2.0.0"
}
},
"node_modules/engine.io-parser": {
"version": "5.2.1",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.1.tgz",
"integrity": "sha512-9JktcM3u18nU9N2Lz3bWeBgxVgOKpw7yhRaoxQA3FUDZzzw+9WlA6p4G4u0RixNkg14fH7EfEc/RhpurtiROTQ==",
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/error-stack-parser": {
"version": "2.1.4",
"resolved": "https://registry.npmjs.org/error-stack-parser/-/error-stack-parser-2.1.4.tgz",
@ -4813,8 +4863,7 @@
"node_modules/ms": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
"integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==",
"dev": true
"integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w=="
},
"node_modules/mz": {
"version": "2.7.0",
@ -5690,6 +5739,32 @@
"node": ">=8"
}
},
"node_modules/socket.io-client": {
"version": "4.7.2",
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.7.2.tgz",
"integrity": "sha512-vtA0uD4ibrYD793SOIAwlo8cj6haOeMHrGvwPxJsxH7CeIksqJ+3Zc06RvWTIFgiSqx4A3sOnTXpfAEE2Zyz6w==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.2",
"engine.io-client": "~6.5.2",
"socket.io-parser": "~4.2.4"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/socket.io-parser": {
"version": "4.2.4",
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz",
"integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/source-map": {
"version": "0.6.1",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
@ -6256,6 +6331,34 @@
"resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz",
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="
},
"node_modules/ws": {
"version": "8.11.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/xmlhttprequest-ssl": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz",
"integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/yallist": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",

View File

@ -19,6 +19,7 @@
"@radix-ui/react-icons": "^1.3.0",
"@radix-ui/react-label": "^2.0.2",
"@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-progress": "^1.0.3",
"@radix-ui/react-radio-group": "^1.1.3",
"@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0",
@ -50,6 +51,7 @@
"react-use": "^17.4.0",
"react-zoom-pan-pinch": "^3.3.0",
"recoil": "^0.7.7",
"socket.io-client": "^4.7.2",
"tailwind-merge": "^2.0.0",
"tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4",

View File

@ -0,0 +1,67 @@
import * as React from "react"
import io from "socket.io-client"
import { Progress } from "./ui/progress"
import { useStore } from "@/lib/states"
import { MODEL_TYPE_INPAINT } from "@/lib/const"
export const API_ENDPOINT = import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND
: ""
const socket = io(API_ENDPOINT)
const DiffusionProgress = () => {
const [settings, isInpainting] = useStore((state) => [
state.settings,
state.isInpainting,
])
const [isConnected, setIsConnected] = React.useState(false)
const [step, setStep] = React.useState(0)
const progress = Math.min(Math.round((step / settings.sdSteps) * 100), 100)
const isSD = settings.model.model_type !== MODEL_TYPE_INPAINT
React.useEffect(() => {
if (!isSD) {
return
}
socket.on("connect", () => {
setIsConnected(true)
})
socket.on("disconnect", () => {
setIsConnected(false)
})
socket.on("diffusion_progress", (data) => {
if (data) {
setStep(data.step + 1)
}
})
socket.on("diffusion_finish", (data) => {
setStep(0)
})
return () => {
socket.off("connect")
socket.off("disconnect")
socket.off("diffusion_progress")
socket.off("diffusion_finish")
}
}, [isSD])
return (
<div
className="fixed w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
style={{
visibility: isInpainting && isConnected && isSD ? "visible" : "hidden",
}}
>
<Progress value={progress} />
<div className="w-[45px] flex justify-center font-nums">{progress}%</div>
</div>
)
}
export default DiffusionProgress

View File

@ -6,6 +6,7 @@ import ImageSize from "./ImageSize"
import Plugins from "./Plugins"
import { InteractiveSeg } from "./InteractiveSeg"
import SidePanel from "./SidePanel"
import DiffusionProgress from "./DiffusionProgress"
const Workspace = () => {
const [file, updateSettings] = useStore((state) => [
@ -28,6 +29,7 @@ const Workspace = () => {
<ImageSize />
</div>
<InteractiveSeg />
<DiffusionProgress />
<SidePanel />
{file ? <Editor file={file} /> : <></>}
</>

View File

@ -0,0 +1,26 @@
import * as React from "react"
import * as ProgressPrimitive from "@radix-ui/react-progress"
import { cn } from "@/lib/utils"
const Progress = React.forwardRef<
React.ElementRef<typeof ProgressPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof ProgressPrimitive.Root>
>(({ className, value, ...props }, ref) => (
<ProgressPrimitive.Root
ref={ref}
className={cn(
"relative h-2 w-full overflow-hidden rounded-full bg-primary/20",
className
)}
{...props}
>
<ProgressPrimitive.Indicator
className="h-full w-full flex-1 bg-primary transition-all"
style={{ transform: `translateX(-${100 - (value || 0)}%)` }}
/>
</ProgressPrimitive.Root>
))
Progress.displayName = ProgressPrimitive.Root.displayName
export { Progress }

View File

@ -11,7 +11,7 @@ import { convertToBase64, srcToFile } from "@/lib/utils"
import axios from "axios"
export const API_ENDPOINT = import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND + "/api/v1"
: "/api/v1"
const api = axios.create({

View File

@ -312,7 +312,7 @@ const defaultValues: AppState = {
sdGuidanceScale: 7.5,
sdSampler: "DPM++ 2M",
sdMatchHistograms: false,
sdScale: 100,
sdScale: 1.0,
p2pImageGuidanceScale: 1.5,
controlnetConditioningScale: 0.4,
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",