switch controlnet in webui

This commit is contained in:
Qing 2023-05-13 13:45:27 +08:00
parent 0363472adc
commit 3eef8f4dae
16 changed files with 161 additions and 46 deletions

View File

@ -1,5 +1,5 @@
import { PluginName } from '../components/Plugins/Plugins'
import { Rect, Settings } from '../store/Atoms'
import { ControlNetMethodMap, Rect, Settings } from '../store/Atoms'
import { dataURItoBlob, loadImage, srcToFile } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
@ -92,6 +92,10 @@ export default async function inpaint(
'controlnet_conditioning_scale',
settings.controlnetConditioningScale.toString()
)
fd.append(
'controlnet_method',
ControlNetMethodMap[settings.controlnetMethod.toString()]
)
try {
const res = await fetch(`${API_ENDPOINT}/inpaint`, {

View File

@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use'
import {
ControlNetMethod,
isControlNetState,
isInpaintingState,
negativePropmtState,
@ -47,6 +48,44 @@ const SidePanel = () => {
}
}
const renderConterNetSetting = () => {
return (
<>
<SettingBlock
className="sub-setting-block"
title="ControlNet"
input={
<Selector
width={80}
value={setting.controlnetMethod as string}
options={Object.values(ControlNetMethod)}
onChange={val => {
const method = val as ControlNetMethod
setSettingState(old => {
return { ...old, controlnetMethod: method }
})
}}
/>
}
/>
<NumberInputSetting
title="ControlNet Weight"
width={INPUT_WIDTH}
allowFloat
value={`${setting.controlnetConditioningScale}`}
desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, controlnetConditioningScale: val }
})
}}
/>
</>
)
}
return (
<div className="side-panel">
<PopoverPrimitive.Root open={open}>
@ -58,6 +97,8 @@ const SidePanel = () => {
</PopoverPrimitive.Trigger>
<PopoverPrimitive.Portal>
<PopoverPrimitive.Content className="side-panel-content">
{isControlNet && renderConterNetSetting()}
<SettingBlock
title="Croper"
input={
@ -117,22 +158,6 @@ const SidePanel = () => {
}}
/>
{isControlNet && (
<NumberInputSetting
title="ControlNet Weight"
width={INPUT_WIDTH}
allowFloat
value={`${setting.controlnetConditioningScale}`}
desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, controlnetConditioningScale: val }
})
}}
/>
)}
<NumberInputSetting
title="Mask Blur"
width={INPUT_WIDTH}

View File

@ -440,6 +440,7 @@ export interface Settings {
// ControlNet
controlnetConditioningScale: number
controlnetMethod: string
}
const defaultHDSettings: ModelsHDSettings = {
@ -546,6 +547,18 @@ export enum SDSampler {
uni_pc = 'uni_pc',
}
export enum ControlNetMethod {
canny = 'canny',
inpaint = 'inpaint',
openpose = 'openpose',
}
export const ControlNetMethodMap: any = {
canny: 'control_v11p_sd15_canny',
inpaint: 'control_v11p_sd15_inpaint',
openpose: 'control_v11p_sd15_openpose',
}
export enum SDMode {
text2img = 'text2img',
img2img = 'img2img',
@ -597,7 +610,8 @@ export const settingStateDefault: Settings = {
p2pGuidanceScale: 7.5,
// ControlNet
controlnetConditioningScale: 0.4,
controlnetConditioningScale: 1.0,
controlnetMethod: ControlNetMethod.canny,
}
const localStorageEffect =

View File

@ -55,6 +55,7 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
SD_CONTROLNET_HELP = """
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
"""
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
SD_CONTROLNET_CHOICES = [
"control_v11p_sd15_canny",
"control_v11p_sd15_openpose",
@ -133,6 +134,7 @@ class Config(BaseModel):
model: str = DEFAULT_MODEL
sd_local_model_path: str = None
sd_controlnet: bool = False
sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
device: str = DEFAULT_DEVICE
gui: bool = False
no_gui_auto_close: bool = False

View File

@ -6,7 +6,12 @@ import torch
import numpy as np
from loguru import logger
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device
from lama_cleaner.helper import (
boxes_from_mask,
resize_max_size,
pad_img_to_modulo,
switch_mps_device,
)
from lama_cleaner.schema import Config, HDStrategy
@ -199,7 +204,9 @@ class InpaintModel:
# only calculate histograms for non-masked parts
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])
reference_histogram, _ = np.histogram(
reference_channel[mask == 0], 256, [0, 256]
)
source_cdf = self._calculate_cdf(source_histogram)
reference_cdf = self._calculate_cdf(reference_histogram)
@ -273,6 +280,7 @@ class DiffusionInpaintModel(InpaintModel):
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
if config.sd_scale != 1:
logger.info(
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
)
@ -284,5 +292,7 @@ class DiffusionInpaintModel(InpaintModel):
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
original_pixel_indices
]
return inpaint_result

View File

@ -203,6 +203,7 @@ class ControlNet(DiffusionInpaintModel):
negative_prompt=config.negative_prompt,
generator=torch.manual_seed(config.sd_seed),
output_type="np.array",
callback=self.callback
).images[0]
else:
if "canny" in self.sd_controlnet_method:

View File

@ -455,7 +455,7 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.controlnet.in_channels
num_channels_latents = self.controlnet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@ -1,6 +1,8 @@
import torch
import gc
from loguru import logger
from lama_cleaner.const import SD15_MODELS
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.controlnet import ControlNet
@ -58,6 +60,7 @@ class ModelManager:
raise NotImplementedError(f"Not supported model: {name}")
def __call__(self, image, mask, config: Config):
self.switch_controlnet_method(control_method=config.controlnet_method)
return self.model(image, mask, config)
def switch(self, new_name: str, **kwargs):
@ -86,7 +89,9 @@ class ModelManager:
del self.model
torch_gc()
old_method = self.kwargs["sd_controlnet_method"]
self.kwargs["sd_controlnet_method"] = control_method
self.model = self.init_model(
self.name, switch_mps_device(self.name, self.device), **self.kwargs
)
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")

View File

@ -40,7 +40,7 @@ def parse_args():
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
parser.add_argument(
"--sd-controlnet-method",
default="control_v11p_sd15_inpaint",
default=DEFAULT_CONTROLNET_METHOD,
choices=SD_CONTROLNET_CHOICES,
)
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)

View File

@ -4,3 +4,4 @@ from .realesrgan import RealESRGANUpscaler
from .gfpgan_plugin import GFPGANPlugin
from .restoreformer import RestoreFormerPlugin
from .gif import MakeGIF
from .anime_seg import AnimeSeg

View File

@ -96,4 +96,5 @@ class Config(BaseModel):
p2p_guidance_scale: float = 7.5
# ControlNet
controlnet_conditioning_scale: float = 0.4
controlnet_conditioning_scale: float = 1.0
controlnet_method: str = "control_v11p_sd15_canny"

View File

@ -1,9 +1,6 @@
#!/usr/bin/env python3
import asyncio
import hashlib
import os
from lama_cleaner.plugins.anime_seg import AnimeSeg
import hashlib
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -32,6 +29,7 @@ from lama_cleaner.plugins import (
MakeGIF,
GFPGANPlugin,
RestoreFormerPlugin,
AnimeSeg,
)
from lama_cleaner.schema import Config
@ -84,7 +82,15 @@ BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
class NoFlaskwebgui(logging.Filter):
def filter(self, record):
return "flaskwebgui-keep-server-alive" not in record.getMessage()
msg = record.getMessage()
if "Running on http:" in msg:
print(msg[msg.index("Running on http:") :])
return (
"flaskwebgui-keep-server-alive" not in msg
and "socket.io" not in msg
and "This is a development server." not in msg
)
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
@ -92,6 +98,9 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition"])
sio_logger = logging.getLogger("sio-logger")
sio_logger.setLevel(logging.ERROR)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
model: ModelManager = None
@ -254,6 +263,7 @@ def process():
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
p2p_guidance_scale=form["p2pGuidanceScale"],
controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
controlnet_method=form["controlnet_method"],
)
if config.sd_seed == -1:
@ -263,7 +273,6 @@ def process():
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
logger.info(f"Resized image shape: {image.shape}")
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
@ -436,17 +445,6 @@ def switch_model():
return f"ok, switch to {new_name}", 200
@app.route("/controlnet_method", methods=["POST"])
def switch_controlnet_method():
new_method = request.form.get("method")
try:
model.switch_controlnet_method(new_method)
except NotImplementedError:
return f"Failed switch to {new_method} not implemented", 500
return f"Switch to {new_method}", 200
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@ -603,4 +601,10 @@ def main(args):
)
ui.run()
else:
socketio.run(app, host=args.host, port=args.port, debug=args.debug)
socketio.run(
app,
host=args.host,
port=args.port,
debug=args.debug,
allow_unsafe_werkzeug=True,
)

View File

@ -125,7 +125,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
prompt="a fox sitting on a bench",
sd_steps=sd_steps,
controlnet_conditioning_scale=1.0,
sd_strength=1.0
sd_strength=1.0,
)
cfg.sd_sampler = sampler
@ -138,3 +138,42 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
def test_controlnet_switch(sd_device, sampler):
if sd_device == "cuda" and not torch.cuda.is_available():
return
if device == "mps" and not torch.backends.mps.is_available():
return
sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=False,
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
sd_controlnet_method="control_v11p_sd15_canny",
)
cfg = get_config(
HDStrategy.ORIGINAL,
prompt="a fox sitting on a bench",
sd_steps=sd_steps,
controlnet_method="control_v11p_sd15_inpaint",
)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}"
assert_equal(
model,
cfg,
f"sd_controlnet_switch_to_inpaint_local_model_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -35,6 +35,7 @@ def _save(img, name):
def test_remove_bg():
model = RemoveBG()
res = model.forward(bgr_img)
res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
_save(res, "test_remove_bg.png")

View File

@ -16,6 +16,7 @@ def save_config(
model,
sd_local_model_path,
sd_controlnet,
sd_controlnet_method,
device,
gui,
no_gui_auto_close,
@ -182,6 +183,11 @@ def main(config_file: str):
sd_controlnet = gr.Checkbox(
init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
)
sd_controlnet_method = gr.Radio(
SD_CONTROLNET_CHOICES,
lable="ControlNet method",
value=init_config.sd_controlnet_method,
)
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
cpu_offload = gr.Checkbox(
init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
@ -207,6 +213,7 @@ def main(config_file: str):
model,
sd_local_model_path,
sd_controlnet,
sd_controlnet_method,
device,
gui,
no_gui_auto_close,

View File

@ -2,6 +2,7 @@ torch>=1.9.0
opencv-python
flask==2.2.3
flask-socketio
simple-websocket
flask_cors
flaskwebgui==0.3.5
pydantic