add diffusion progress

This commit is contained in:
Qing 2024-01-02 17:13:11 +08:00
parent f38be37f8c
commit 6253016019
17 changed files with 239 additions and 42 deletions

View File

@ -6,6 +6,9 @@ from pathlib import Path
from typing import Optional, Dict, List from typing import Optional, Dict, List
import cv2 import cv2
import socketio
import asyncio
from socketio import AsyncServer
import torch import torch
import numpy as np import numpy as np
from loguru import logger from loguru import logger
@ -109,6 +112,19 @@ def api_middleware(app: FastAPI):
app.add_middleware(CORSMiddleware, **cors_options) app.add_middleware(CORSMiddleware, **cors_options)
global_sio: AsyncServer = None
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict):
# self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
# logger.info(f"diffusion callback: step={step}, timestep={timestep}")
# We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
# but for now let's just start a separate event loop. It shouldn't make a difference for single person use
asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
return {}
class Api: class Api:
def __init__(self, app: FastAPI, config: ApiConfig): def __init__(self, app: FastAPI, config: ApiConfig):
self.app = app self.app = app
@ -134,6 +150,12 @@ class Api:
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on # fmt: on
global global_sio
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
self.app.mount("/ws", self.combined_asgi_app)
global_sio = self.sio
def add_api_route(self, path: str, endpoint, **kwargs): def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs) return self.app.add_api_route(path, endpoint, **kwargs)
@ -206,6 +228,9 @@ class Api:
quality=self.config.quality, quality=self.config.quality,
infos=infos, infos=infos,
) )
asyncio.run(self.sio.emit("diffusion_finish"))
return Response( return Response(
content=res_img_bytes, content=res_img_bytes,
media_type=f"image/{ext}", media_type=f"image/{ext}",
@ -246,10 +271,10 @@ class Api:
def launch(self): def launch(self):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run( uvicorn.run(
self.app, self.combined_asgi_app,
host=self.config.host, host=self.config.host,
port=self.config.port, port=self.config.port,
timeout_keep_alive=60, timeout_keep_alive=999999999,
) )
def _build_file_manager(self) -> Optional[FileManager]: def _build_file_manager(self) -> Optional[FileManager]:
@ -290,6 +315,7 @@ class Api:
disable_nsfw=self.config.disable_nsfw_checker, disable_nsfw=self.config.disable_nsfw_checker,
sd_cpu_textencoder=self.config.cpu_textencoder, sd_cpu_textencoder=self.config.cpu_textencoder,
cpu_offload=self.config.cpu_offload, cpu_offload=self.config.cpu_offload,
callback=diffuser_callback,
) )

View File

@ -38,7 +38,7 @@ class ControlNet(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get("no_half", False) fp16 = not kwargs.get("no_half", False)
model_info = kwargs["model_info"] model_info = kwargs["model_info"]
controlnet_method = kwargs["controlnet_method"] controlnet_method = kwargs["controlnet_method"]
self.model_info = model_info self.model_info = model_info
@ -154,7 +154,7 @@ class ControlNet(DiffusionInpaintModel):
num_inference_steps=config.sd_steps, num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale, guidance_scale=config.sd_guidance_scale,
output_type="np", output_type="np",
callback=self.callback, callback_on_step_end=self.callback,
height=img_h, height=img_h,
width=img_w, width=img_w,
generator=torch.manual_seed(config.sd_seed), generator=torch.manual_seed(config.sd_seed),

View File

@ -52,9 +52,8 @@ class Kandinsky(DiffusionInpaintModel):
num_inference_steps=config.sd_steps, num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale, guidance_scale=config.sd_guidance_scale,
output_type="np", output_type="np",
callback=self.callback, callback_on_step_end=self.callback,
generator=generator, generator=generator,
callback_steps=1,
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")

View File

@ -83,11 +83,10 @@ class SD(DiffusionInpaintModel):
strength=config.sd_strength, strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale, guidance_scale=config.sd_guidance_scale,
output_type="np", output_type="np",
callback=self.callback, callback_on_step_end=self.callback,
height=img_h, height=img_h,
width=img_w, width=img_w,
generator=torch.manual_seed(config.sd_seed), generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")

View File

@ -2,7 +2,6 @@ import os
import PIL.Image import PIL.Image
import cv2 import cv2
import numpy as np
import torch import torch
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from loguru import logger from loguru import logger
@ -79,11 +78,10 @@ class SDXL(DiffusionInpaintModel):
strength=0.999 if config.sd_strength == 1.0 else config.sd_strength, strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
guidance_scale=config.sd_guidance_scale, guidance_scale=config.sd_guidance_scale,
output_type="np", output_type="np",
callback=self.callback, callback_on_step_end=self.callback,
height=img_h, height=img_h,
width=img_w, width=img_w,
generator=torch.manual_seed(config.sd_seed), generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")

View File

@ -977,7 +977,6 @@ def handle_from_pretrained_exceptions(func, **kwargs):
try: try:
return func(**kwargs) return func(**kwargs)
except ValueError as e: except ValueError as e:
# 处理异常的逻辑
if "You are trying to load the model files of the `variant=fp16`" in str(e): if "You are trying to load the model files of the `variant=fp16`" in str(e):
logger.info("variant=fp16 not found, try revision=fp16") logger.info("variant=fp16 not found, try revision=fp16")
return func(**{**kwargs, "variant": None, "revision": "fp16"}) return func(**{**kwargs, "variant": None, "revision": "fp16"})

View File

@ -18,7 +18,6 @@ def test_model_switch():
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=True, sd_cpu_textencoder=True,
cpu_offload=False, cpu_offload=False,
callback=None,
) )
model.switch("lama") model.switch("lama")
@ -34,7 +33,6 @@ def test_controlnet_switch_onoff(caplog):
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=True, sd_cpu_textencoder=True,
cpu_offload=False, cpu_offload=False,
callback=None,
) )
model.switch_controlnet_method( model.switch_controlnet_method(
@ -59,7 +57,6 @@ def test_switch_controlnet_method(caplog):
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=True, sd_cpu_textencoder=True,
cpu_offload=False, cpu_offload=False,
callback=None,
) )
model.switch_controlnet_method( model.switch_controlnet_method(

View File

@ -30,15 +30,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
def test_outpainting(name, device, rect): def test_outpainting(name, device, rect):
sd_steps = check_device(device) sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager( model = ModelManager(
name=name, name=name,
device=torch.device(device), device=torch.device(device),
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
callback=callback,
) )
cfg = get_config( cfg = get_config(
prompt="a dog sitting on a bench in the park", prompt="a dog sitting on a bench in the park",
@ -72,15 +68,11 @@ def test_outpainting(name, device, rect):
def test_kandinsky_outpainting(name, device, rect): def test_kandinsky_outpainting(name, device, rect):
sd_steps = check_device(device) sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager( model = ModelManager(
name=name, name=name,
device=torch.device(device), device=torch.device(device),
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
callback=callback,
) )
cfg = get_config( cfg = get_config(
prompt="a cat", prompt="a cat",
@ -117,15 +109,11 @@ def test_kandinsky_outpainting(name, device, rect):
def test_powerpaint_outpainting(name, device, rect): def test_powerpaint_outpainting(name, device, rect):
sd_steps = check_device(device) sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager( model = ModelManager(
name=name, name=name,
device=torch.device(device), device=torch.device(device),
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
callback=callback,
) )
cfg = get_config( cfg = get_config(
prompt="a dog sitting on a bench in the park", prompt="a dog sitting on a bench in the park",

View File

@ -18,15 +18,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
def test_sdxl(device, strategy, sampler): def test_sdxl(device, strategy, sampler):
sd_steps = check_device(device) sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager( model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(device), device=torch.device(device),
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
callback=callback,
) )
cfg = get_config( cfg = get_config(
strategy=strategy, strategy=strategy,
@ -54,15 +50,11 @@ def test_sdxl(device, strategy, sampler):
def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler): def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
sd_steps = check_device(device) sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager( model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(device), device=torch.device(device),
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
callback=callback,
) )
cfg = get_config( cfg = get_config(
strategy=strategy, strategy=strategy,

View File

@ -5,8 +5,7 @@ transformers==4.34.1
safetensors safetensors
controlnet-aux==0.0.3 controlnet-aux==0.0.3
fastapi==0.108.0 fastapi==0.108.0
python-multipart python-socketio==5.7.2
simple-websocket
flaskwebgui==0.3.5 flaskwebgui==0.3.5
typer typer
pydantic pydantic

View File

@ -17,6 +17,7 @@
"@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-icons": "^1.3.0",
"@radix-ui/react-label": "^2.0.2", "@radix-ui/react-label": "^2.0.2",
"@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-progress": "^1.0.3",
"@radix-ui/react-radio-group": "^1.1.3", "@radix-ui/react-radio-group": "^1.1.3",
"@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0", "@radix-ui/react-select": "^2.0.0",
@ -48,6 +49,7 @@
"react-use": "^17.4.0", "react-use": "^17.4.0",
"react-zoom-pan-pinch": "^3.3.0", "react-zoom-pan-pinch": "^3.3.0",
"recoil": "^0.7.7", "recoil": "^0.7.7",
"socket.io-client": "^4.7.2",
"tailwind-merge": "^2.0.0", "tailwind-merge": "^2.0.0",
"tailwindcss-animate": "^1.0.7", "tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4", "zod": "^3.22.4",
@ -1914,6 +1916,30 @@
} }
} }
}, },
"node_modules/@radix-ui/react-progress": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.0.3.tgz",
"integrity": "sha512-5G6Om/tYSxjSeEdrb1VfKkfZfn/1IlPWd731h2RfPuSbIfNUgfqAwbKfJCg/PP6nuUCTrYzalwHSpSinoWoCag==",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/react-context": "1.0.1",
"@radix-ui/react-primitive": "1.0.3"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0",
"react-dom": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-radio-group": { "node_modules/@radix-ui/react-radio-group": {
"version": "1.1.3", "version": "1.1.3",
"resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.1.3.tgz", "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.1.3.tgz",
@ -2587,6 +2613,11 @@
"win32" "win32"
] ]
}, },
"node_modules/@socket.io/component-emitter": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz",
"integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg=="
},
"node_modules/@swc/core": { "node_modules/@swc/core": {
"version": "1.3.96", "version": "1.3.96",
"resolved": "https://registry.npmjs.org/@swc/core/-/core-1.3.96.tgz", "resolved": "https://registry.npmjs.org/@swc/core/-/core-1.3.96.tgz",
@ -3804,7 +3835,6 @@
"version": "4.3.4", "version": "4.3.4",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
"dev": true,
"dependencies": { "dependencies": {
"ms": "2.1.2" "ms": "2.1.2"
}, },
@ -3876,6 +3906,26 @@
"integrity": "sha512-soytjxwbgcCu7nh5Pf4S2/4wa6UIu+A3p03U2yVr53qGxi1/VTR3ENI+p50v+UxqqZAfl48j3z55ud7VHIOr9w==", "integrity": "sha512-soytjxwbgcCu7nh5Pf4S2/4wa6UIu+A3p03U2yVr53qGxi1/VTR3ENI+p50v+UxqqZAfl48j3z55ud7VHIOr9w==",
"dev": true "dev": true
}, },
"node_modules/engine.io-client": {
"version": "6.5.3",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz",
"integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.2.1",
"ws": "~8.11.0",
"xmlhttprequest-ssl": "~2.0.0"
}
},
"node_modules/engine.io-parser": {
"version": "5.2.1",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.1.tgz",
"integrity": "sha512-9JktcM3u18nU9N2Lz3bWeBgxVgOKpw7yhRaoxQA3FUDZzzw+9WlA6p4G4u0RixNkg14fH7EfEc/RhpurtiROTQ==",
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/error-stack-parser": { "node_modules/error-stack-parser": {
"version": "2.1.4", "version": "2.1.4",
"resolved": "https://registry.npmjs.org/error-stack-parser/-/error-stack-parser-2.1.4.tgz", "resolved": "https://registry.npmjs.org/error-stack-parser/-/error-stack-parser-2.1.4.tgz",
@ -4813,8 +4863,7 @@
"node_modules/ms": { "node_modules/ms": {
"version": "2.1.2", "version": "2.1.2",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz",
"integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w=="
"dev": true
}, },
"node_modules/mz": { "node_modules/mz": {
"version": "2.7.0", "version": "2.7.0",
@ -5690,6 +5739,32 @@
"node": ">=8" "node": ">=8"
} }
}, },
"node_modules/socket.io-client": {
"version": "4.7.2",
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.7.2.tgz",
"integrity": "sha512-vtA0uD4ibrYD793SOIAwlo8cj6haOeMHrGvwPxJsxH7CeIksqJ+3Zc06RvWTIFgiSqx4A3sOnTXpfAEE2Zyz6w==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.2",
"engine.io-client": "~6.5.2",
"socket.io-parser": "~4.2.4"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/socket.io-parser": {
"version": "4.2.4",
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz",
"integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/source-map": { "node_modules/source-map": {
"version": "0.6.1", "version": "0.6.1",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
@ -6256,6 +6331,34 @@
"resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz",
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==" "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="
}, },
"node_modules/ws": {
"version": "8.11.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/xmlhttprequest-ssl": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz",
"integrity": "sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/yallist": { "node_modules/yallist": {
"version": "4.0.0", "version": "4.0.0",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",

View File

@ -19,6 +19,7 @@
"@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-icons": "^1.3.0",
"@radix-ui/react-label": "^2.0.2", "@radix-ui/react-label": "^2.0.2",
"@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-progress": "^1.0.3",
"@radix-ui/react-radio-group": "^1.1.3", "@radix-ui/react-radio-group": "^1.1.3",
"@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0", "@radix-ui/react-select": "^2.0.0",
@ -50,6 +51,7 @@
"react-use": "^17.4.0", "react-use": "^17.4.0",
"react-zoom-pan-pinch": "^3.3.0", "react-zoom-pan-pinch": "^3.3.0",
"recoil": "^0.7.7", "recoil": "^0.7.7",
"socket.io-client": "^4.7.2",
"tailwind-merge": "^2.0.0", "tailwind-merge": "^2.0.0",
"tailwindcss-animate": "^1.0.7", "tailwindcss-animate": "^1.0.7",
"zod": "^3.22.4", "zod": "^3.22.4",

View File

@ -0,0 +1,67 @@
import * as React from "react"
import io from "socket.io-client"
import { Progress } from "./ui/progress"
import { useStore } from "@/lib/states"
import { MODEL_TYPE_INPAINT } from "@/lib/const"
export const API_ENDPOINT = import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND
: ""
const socket = io(API_ENDPOINT)
const DiffusionProgress = () => {
const [settings, isInpainting] = useStore((state) => [
state.settings,
state.isInpainting,
])
const [isConnected, setIsConnected] = React.useState(false)
const [step, setStep] = React.useState(0)
const progress = Math.min(Math.round((step / settings.sdSteps) * 100), 100)
const isSD = settings.model.model_type !== MODEL_TYPE_INPAINT
React.useEffect(() => {
if (!isSD) {
return
}
socket.on("connect", () => {
setIsConnected(true)
})
socket.on("disconnect", () => {
setIsConnected(false)
})
socket.on("diffusion_progress", (data) => {
if (data) {
setStep(data.step + 1)
}
})
socket.on("diffusion_finish", (data) => {
setStep(0)
})
return () => {
socket.off("connect")
socket.off("disconnect")
socket.off("diffusion_progress")
socket.off("diffusion_finish")
}
}, [isSD])
return (
<div
className="fixed w-[220px] left-1/2 -translate-x-1/2 top-[68px] h-[32px] flex justify-center items-center gap-[18px] border-[1px] border-[solid] rounded-[14px] pl-[8px] pr-[8px]"
style={{
visibility: isInpainting && isConnected && isSD ? "visible" : "hidden",
}}
>
<Progress value={progress} />
<div className="w-[45px] flex justify-center font-nums">{progress}%</div>
</div>
)
}
export default DiffusionProgress

View File

@ -6,6 +6,7 @@ import ImageSize from "./ImageSize"
import Plugins from "./Plugins" import Plugins from "./Plugins"
import { InteractiveSeg } from "./InteractiveSeg" import { InteractiveSeg } from "./InteractiveSeg"
import SidePanel from "./SidePanel" import SidePanel from "./SidePanel"
import DiffusionProgress from "./DiffusionProgress"
const Workspace = () => { const Workspace = () => {
const [file, updateSettings] = useStore((state) => [ const [file, updateSettings] = useStore((state) => [
@ -28,6 +29,7 @@ const Workspace = () => {
<ImageSize /> <ImageSize />
</div> </div>
<InteractiveSeg /> <InteractiveSeg />
<DiffusionProgress />
<SidePanel /> <SidePanel />
{file ? <Editor file={file} /> : <></>} {file ? <Editor file={file} /> : <></>}
</> </>

View File

@ -0,0 +1,26 @@
import * as React from "react"
import * as ProgressPrimitive from "@radix-ui/react-progress"
import { cn } from "@/lib/utils"
const Progress = React.forwardRef<
React.ElementRef<typeof ProgressPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof ProgressPrimitive.Root>
>(({ className, value, ...props }, ref) => (
<ProgressPrimitive.Root
ref={ref}
className={cn(
"relative h-2 w-full overflow-hidden rounded-full bg-primary/20",
className
)}
{...props}
>
<ProgressPrimitive.Indicator
className="h-full w-full flex-1 bg-primary transition-all"
style={{ transform: `translateX(-${100 - (value || 0)}%)` }}
/>
</ProgressPrimitive.Root>
))
Progress.displayName = ProgressPrimitive.Root.displayName
export { Progress }

View File

@ -11,7 +11,7 @@ import { convertToBase64, srcToFile } from "@/lib/utils"
import axios from "axios" import axios from "axios"
export const API_ENDPOINT = import.meta.env.VITE_BACKEND export const API_ENDPOINT = import.meta.env.VITE_BACKEND
? import.meta.env.VITE_BACKEND ? import.meta.env.VITE_BACKEND + "/api/v1"
: "/api/v1" : "/api/v1"
const api = axios.create({ const api = axios.create({

View File

@ -312,7 +312,7 @@ const defaultValues: AppState = {
sdGuidanceScale: 7.5, sdGuidanceScale: 7.5,
sdSampler: "DPM++ 2M", sdSampler: "DPM++ 2M",
sdMatchHistograms: false, sdMatchHistograms: false,
sdScale: 100, sdScale: 1.0,
p2pImageGuidanceScale: 1.5, p2pImageGuidanceScale: 1.5,
controlnetConditioningScale: 0.4, controlnetConditioningScale: 0.4,
controlnetMethod: "lllyasviel/control_v11p_sd15_canny", controlnetMethod: "lllyasviel/control_v11p_sd15_canny",