switch controlnet in webui
This commit is contained in:
parent
0363472adc
commit
3eef8f4dae
@ -1,5 +1,5 @@
|
||||
import { PluginName } from '../components/Plugins/Plugins'
|
||||
import { Rect, Settings } from '../store/Atoms'
|
||||
import { ControlNetMethodMap, Rect, Settings } from '../store/Atoms'
|
||||
import { dataURItoBlob, loadImage, srcToFile } from '../utils'
|
||||
|
||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||
@ -92,6 +92,10 @@ export default async function inpaint(
|
||||
'controlnet_conditioning_scale',
|
||||
settings.controlnetConditioningScale.toString()
|
||||
)
|
||||
fd.append(
|
||||
'controlnet_method',
|
||||
ControlNetMethodMap[settings.controlnetMethod.toString()]
|
||||
)
|
||||
|
||||
try {
|
||||
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
|
||||
|
@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
||||
import { useToggle } from 'react-use'
|
||||
import {
|
||||
ControlNetMethod,
|
||||
isControlNetState,
|
||||
isInpaintingState,
|
||||
negativePropmtState,
|
||||
@ -47,6 +48,44 @@ const SidePanel = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const renderConterNetSetting = () => {
|
||||
return (
|
||||
<>
|
||||
<SettingBlock
|
||||
className="sub-setting-block"
|
||||
title="ControlNet"
|
||||
input={
|
||||
<Selector
|
||||
width={80}
|
||||
value={setting.controlnetMethod as string}
|
||||
options={Object.values(ControlNetMethod)}
|
||||
onChange={val => {
|
||||
const method = val as ControlNetMethod
|
||||
setSettingState(old => {
|
||||
return { ...old, controlnetMethod: method }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
|
||||
<NumberInputSetting
|
||||
title="ControlNet Weight"
|
||||
width={INPUT_WIDTH}
|
||||
allowFloat
|
||||
value={`${setting.controlnetConditioningScale}`}
|
||||
desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseFloat(value)
|
||||
setSettingState(old => {
|
||||
return { ...old, controlnetConditioningScale: val }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="side-panel">
|
||||
<PopoverPrimitive.Root open={open}>
|
||||
@ -58,6 +97,8 @@ const SidePanel = () => {
|
||||
</PopoverPrimitive.Trigger>
|
||||
<PopoverPrimitive.Portal>
|
||||
<PopoverPrimitive.Content className="side-panel-content">
|
||||
{isControlNet && renderConterNetSetting()}
|
||||
|
||||
<SettingBlock
|
||||
title="Croper"
|
||||
input={
|
||||
@ -117,22 +158,6 @@ const SidePanel = () => {
|
||||
}}
|
||||
/>
|
||||
|
||||
{isControlNet && (
|
||||
<NumberInputSetting
|
||||
title="ControlNet Weight"
|
||||
width={INPUT_WIDTH}
|
||||
allowFloat
|
||||
value={`${setting.controlnetConditioningScale}`}
|
||||
desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseFloat(value)
|
||||
setSettingState(old => {
|
||||
return { ...old, controlnetConditioningScale: val }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
<NumberInputSetting
|
||||
title="Mask Blur"
|
||||
width={INPUT_WIDTH}
|
||||
|
@ -440,6 +440,7 @@ export interface Settings {
|
||||
|
||||
// ControlNet
|
||||
controlnetConditioningScale: number
|
||||
controlnetMethod: string
|
||||
}
|
||||
|
||||
const defaultHDSettings: ModelsHDSettings = {
|
||||
@ -546,6 +547,18 @@ export enum SDSampler {
|
||||
uni_pc = 'uni_pc',
|
||||
}
|
||||
|
||||
export enum ControlNetMethod {
|
||||
canny = 'canny',
|
||||
inpaint = 'inpaint',
|
||||
openpose = 'openpose',
|
||||
}
|
||||
|
||||
export const ControlNetMethodMap: any = {
|
||||
canny: 'control_v11p_sd15_canny',
|
||||
inpaint: 'control_v11p_sd15_inpaint',
|
||||
openpose: 'control_v11p_sd15_openpose',
|
||||
}
|
||||
|
||||
export enum SDMode {
|
||||
text2img = 'text2img',
|
||||
img2img = 'img2img',
|
||||
@ -597,7 +610,8 @@ export const settingStateDefault: Settings = {
|
||||
p2pGuidanceScale: 7.5,
|
||||
|
||||
// ControlNet
|
||||
controlnetConditioningScale: 0.4,
|
||||
controlnetConditioningScale: 1.0,
|
||||
controlnetMethod: ControlNetMethod.canny,
|
||||
}
|
||||
|
||||
const localStorageEffect =
|
||||
|
@ -55,6 +55,7 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
|
||||
SD_CONTROLNET_HELP = """
|
||||
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
|
||||
"""
|
||||
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
|
||||
SD_CONTROLNET_CHOICES = [
|
||||
"control_v11p_sd15_canny",
|
||||
"control_v11p_sd15_openpose",
|
||||
@ -133,6 +134,7 @@ class Config(BaseModel):
|
||||
model: str = DEFAULT_MODEL
|
||||
sd_local_model_path: str = None
|
||||
sd_controlnet: bool = False
|
||||
sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
|
||||
device: str = DEFAULT_DEVICE
|
||||
gui: bool = False
|
||||
no_gui_auto_close: bool = False
|
||||
|
@ -6,7 +6,12 @@ import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device
|
||||
from lama_cleaner.helper import (
|
||||
boxes_from_mask,
|
||||
resize_max_size,
|
||||
pad_img_to_modulo,
|
||||
switch_mps_device,
|
||||
)
|
||||
from lama_cleaner.schema import Config, HDStrategy
|
||||
|
||||
|
||||
@ -199,7 +204,9 @@ class InpaintModel:
|
||||
|
||||
# only calculate histograms for non-masked parts
|
||||
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
|
||||
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])
|
||||
reference_histogram, _ = np.histogram(
|
||||
reference_channel[mask == 0], 256, [0, 256]
|
||||
)
|
||||
|
||||
source_cdf = self._calculate_cdf(source_histogram)
|
||||
reference_cdf = self._calculate_cdf(reference_histogram)
|
||||
@ -273,6 +280,7 @@ class DiffusionInpaintModel(InpaintModel):
|
||||
origin_size = image.shape[:2]
|
||||
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
||||
if config.sd_scale != 1:
|
||||
logger.info(
|
||||
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
||||
)
|
||||
@ -284,5 +292,7 @@ class DiffusionInpaintModel(InpaintModel):
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
original_pixel_indices = mask < 127
|
||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
||||
original_pixel_indices
|
||||
]
|
||||
return inpaint_result
|
||||
|
@ -203,6 +203,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
negative_prompt=config.negative_prompt,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
output_type="np.array",
|
||||
callback=self.callback
|
||||
).images[0]
|
||||
else:
|
||||
if "canny" in self.sd_controlnet_method:
|
||||
|
@ -455,7 +455,7 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.controlnet.in_channels
|
||||
num_channels_latents = self.controlnet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
|
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import gc
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import SD15_MODELS
|
||||
from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model.controlnet import ControlNet
|
||||
@ -58,6 +60,7 @@ class ModelManager:
|
||||
raise NotImplementedError(f"Not supported model: {name}")
|
||||
|
||||
def __call__(self, image, mask, config: Config):
|
||||
self.switch_controlnet_method(control_method=config.controlnet_method)
|
||||
return self.model(image, mask, config)
|
||||
|
||||
def switch(self, new_name: str, **kwargs):
|
||||
@ -86,7 +89,9 @@ class ModelManager:
|
||||
del self.model
|
||||
torch_gc()
|
||||
|
||||
old_method = self.kwargs["sd_controlnet_method"]
|
||||
self.kwargs["sd_controlnet_method"] = control_method
|
||||
self.model = self.init_model(
|
||||
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
||||
)
|
||||
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
|
||||
|
@ -40,7 +40,7 @@ def parse_args():
|
||||
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
|
||||
parser.add_argument(
|
||||
"--sd-controlnet-method",
|
||||
default="control_v11p_sd15_inpaint",
|
||||
default=DEFAULT_CONTROLNET_METHOD,
|
||||
choices=SD_CONTROLNET_CHOICES,
|
||||
)
|
||||
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
|
||||
|
@ -4,3 +4,4 @@ from .realesrgan import RealESRGANUpscaler
|
||||
from .gfpgan_plugin import GFPGANPlugin
|
||||
from .restoreformer import RestoreFormerPlugin
|
||||
from .gif import MakeGIF
|
||||
from .anime_seg import AnimeSeg
|
||||
|
@ -96,4 +96,5 @@ class Config(BaseModel):
|
||||
p2p_guidance_scale: float = 7.5
|
||||
|
||||
# ControlNet
|
||||
controlnet_conditioning_scale: float = 0.4
|
||||
controlnet_conditioning_scale: float = 1.0
|
||||
controlnet_method: str = "control_v11p_sd15_canny"
|
||||
|
@ -1,9 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
||||
import hashlib
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
@ -32,6 +29,7 @@ from lama_cleaner.plugins import (
|
||||
MakeGIF,
|
||||
GFPGANPlugin,
|
||||
RestoreFormerPlugin,
|
||||
AnimeSeg,
|
||||
)
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
@ -84,7 +82,15 @@ BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
||||
|
||||
class NoFlaskwebgui(logging.Filter):
|
||||
def filter(self, record):
|
||||
return "flaskwebgui-keep-server-alive" not in record.getMessage()
|
||||
msg = record.getMessage()
|
||||
if "Running on http:" in msg:
|
||||
print(msg[msg.index("Running on http:") :])
|
||||
|
||||
return (
|
||||
"flaskwebgui-keep-server-alive" not in msg
|
||||
and "socket.io" not in msg
|
||||
and "This is a development server." not in msg
|
||||
)
|
||||
|
||||
|
||||
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
||||
@ -92,6 +98,9 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
CORS(app, expose_headers=["Content-Disposition"])
|
||||
|
||||
sio_logger = logging.getLogger("sio-logger")
|
||||
sio_logger.setLevel(logging.ERROR)
|
||||
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
|
||||
|
||||
model: ModelManager = None
|
||||
@ -254,6 +263,7 @@ def process():
|
||||
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
|
||||
p2p_guidance_scale=form["p2pGuidanceScale"],
|
||||
controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
|
||||
controlnet_method=form["controlnet_method"],
|
||||
)
|
||||
|
||||
if config.sd_seed == -1:
|
||||
@ -263,7 +273,6 @@ def process():
|
||||
|
||||
logger.info(f"Origin image shape: {original_shape}")
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
logger.info(f"Resized image shape: {image.shape}")
|
||||
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
|
||||
@ -436,17 +445,6 @@ def switch_model():
|
||||
return f"ok, switch to {new_name}", 200
|
||||
|
||||
|
||||
@app.route("/controlnet_method", methods=["POST"])
|
||||
def switch_controlnet_method():
|
||||
new_method = request.form.get("method")
|
||||
|
||||
try:
|
||||
model.switch_controlnet_method(new_method)
|
||||
except NotImplementedError:
|
||||
return f"Failed switch to {new_method} not implemented", 500
|
||||
return f"Switch to {new_method}", 200
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
@ -603,4 +601,10 @@ def main(args):
|
||||
)
|
||||
ui.run()
|
||||
else:
|
||||
socketio.run(app, host=args.host, port=args.port, debug=args.debug)
|
||||
socketio.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
debug=args.debug,
|
||||
allow_unsafe_werkzeug=True,
|
||||
)
|
||||
|
@ -125,7 +125,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
controlnet_conditioning_scale=1.0,
|
||||
sd_strength=1.0
|
||||
sd_strength=1.0,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
@ -138,3 +138,42 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
|
||||
def test_controlnet_switch(sd_device, sampler):
|
||||
if sd_device == "cuda" and not torch.cuda.is_available():
|
||||
return
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
return
|
||||
|
||||
sd_steps = 1 if sd_device == "cpu" else 30
|
||||
model = ModelManager(
|
||||
name="sd1.5",
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=False,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
sd_controlnet_method="control_v11p_sd15_canny",
|
||||
)
|
||||
cfg = get_config(
|
||||
HDStrategy.ORIGINAL,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
controlnet_method="control_v11p_sd15_inpaint",
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"device_{sd_device}_{sampler}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_controlnet_switch_to_inpaint_local_model_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
@ -35,6 +35,7 @@ def _save(img, name):
|
||||
def test_remove_bg():
|
||||
model = RemoveBG()
|
||||
res = model.forward(bgr_img)
|
||||
res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
|
||||
_save(res, "test_remove_bg.png")
|
||||
|
||||
|
||||
|
@ -16,6 +16,7 @@ def save_config(
|
||||
model,
|
||||
sd_local_model_path,
|
||||
sd_controlnet,
|
||||
sd_controlnet_method,
|
||||
device,
|
||||
gui,
|
||||
no_gui_auto_close,
|
||||
@ -182,6 +183,11 @@ def main(config_file: str):
|
||||
sd_controlnet = gr.Checkbox(
|
||||
init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
|
||||
)
|
||||
sd_controlnet_method = gr.Radio(
|
||||
SD_CONTROLNET_CHOICES,
|
||||
lable="ControlNet method",
|
||||
value=init_config.sd_controlnet_method,
|
||||
)
|
||||
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
|
||||
cpu_offload = gr.Checkbox(
|
||||
init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
|
||||
@ -207,6 +213,7 @@ def main(config_file: str):
|
||||
model,
|
||||
sd_local_model_path,
|
||||
sd_controlnet,
|
||||
sd_controlnet_method,
|
||||
device,
|
||||
gui,
|
||||
no_gui_auto_close,
|
||||
|
@ -2,6 +2,7 @@ torch>=1.9.0
|
||||
opencv-python
|
||||
flask==2.2.3
|
||||
flask-socketio
|
||||
simple-websocket
|
||||
flask_cors
|
||||
flaskwebgui==0.3.5
|
||||
pydantic
|
||||
|
Loading…
Reference in New Issue
Block a user