switch controlnet in webui

This commit is contained in:
Qing 2023-05-13 13:45:27 +08:00
parent 0363472adc
commit 3eef8f4dae
16 changed files with 161 additions and 46 deletions

View File

@ -1,5 +1,5 @@
import { PluginName } from '../components/Plugins/Plugins' import { PluginName } from '../components/Plugins/Plugins'
import { Rect, Settings } from '../store/Atoms' import { ControlNetMethodMap, Rect, Settings } from '../store/Atoms'
import { dataURItoBlob, loadImage, srcToFile } from '../utils' import { dataURItoBlob, loadImage, srcToFile } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}` export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
@ -92,6 +92,10 @@ export default async function inpaint(
'controlnet_conditioning_scale', 'controlnet_conditioning_scale',
settings.controlnetConditioningScale.toString() settings.controlnetConditioningScale.toString()
) )
fd.append(
'controlnet_method',
ControlNetMethodMap[settings.controlnetMethod.toString()]
)
try { try {
const res = await fetch(`${API_ENDPOINT}/inpaint`, { const res = await fetch(`${API_ENDPOINT}/inpaint`, {

View File

@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover' import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use' import { useToggle } from 'react-use'
import { import {
ControlNetMethod,
isControlNetState, isControlNetState,
isInpaintingState, isInpaintingState,
negativePropmtState, negativePropmtState,
@ -47,6 +48,44 @@ const SidePanel = () => {
} }
} }
const renderConterNetSetting = () => {
return (
<>
<SettingBlock
className="sub-setting-block"
title="ControlNet"
input={
<Selector
width={80}
value={setting.controlnetMethod as string}
options={Object.values(ControlNetMethod)}
onChange={val => {
const method = val as ControlNetMethod
setSettingState(old => {
return { ...old, controlnetMethod: method }
})
}}
/>
}
/>
<NumberInputSetting
title="ControlNet Weight"
width={INPUT_WIDTH}
allowFloat
value={`${setting.controlnetConditioningScale}`}
desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, controlnetConditioningScale: val }
})
}}
/>
</>
)
}
return ( return (
<div className="side-panel"> <div className="side-panel">
<PopoverPrimitive.Root open={open}> <PopoverPrimitive.Root open={open}>
@ -58,6 +97,8 @@ const SidePanel = () => {
</PopoverPrimitive.Trigger> </PopoverPrimitive.Trigger>
<PopoverPrimitive.Portal> <PopoverPrimitive.Portal>
<PopoverPrimitive.Content className="side-panel-content"> <PopoverPrimitive.Content className="side-panel-content">
{isControlNet && renderConterNetSetting()}
<SettingBlock <SettingBlock
title="Croper" title="Croper"
input={ input={
@ -117,22 +158,6 @@ const SidePanel = () => {
}} }}
/> />
{isControlNet && (
<NumberInputSetting
title="ControlNet Weight"
width={INPUT_WIDTH}
allowFloat
value={`${setting.controlnetConditioningScale}`}
desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, controlnetConditioningScale: val }
})
}}
/>
)}
<NumberInputSetting <NumberInputSetting
title="Mask Blur" title="Mask Blur"
width={INPUT_WIDTH} width={INPUT_WIDTH}

View File

@ -440,6 +440,7 @@ export interface Settings {
// ControlNet // ControlNet
controlnetConditioningScale: number controlnetConditioningScale: number
controlnetMethod: string
} }
const defaultHDSettings: ModelsHDSettings = { const defaultHDSettings: ModelsHDSettings = {
@ -546,6 +547,18 @@ export enum SDSampler {
uni_pc = 'uni_pc', uni_pc = 'uni_pc',
} }
export enum ControlNetMethod {
canny = 'canny',
inpaint = 'inpaint',
openpose = 'openpose',
}
export const ControlNetMethodMap: any = {
canny: 'control_v11p_sd15_canny',
inpaint: 'control_v11p_sd15_inpaint',
openpose: 'control_v11p_sd15_openpose',
}
export enum SDMode { export enum SDMode {
text2img = 'text2img', text2img = 'text2img',
img2img = 'img2img', img2img = 'img2img',
@ -597,7 +610,8 @@ export const settingStateDefault: Settings = {
p2pGuidanceScale: 7.5, p2pGuidanceScale: 7.5,
// ControlNet // ControlNet
controlnetConditioningScale: 0.4, controlnetConditioningScale: 1.0,
controlnetMethod: ControlNetMethod.canny,
} }
const localStorageEffect = const localStorageEffect =

View File

@ -55,6 +55,7 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
SD_CONTROLNET_HELP = """ SD_CONTROLNET_HELP = """
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui. Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
""" """
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
SD_CONTROLNET_CHOICES = [ SD_CONTROLNET_CHOICES = [
"control_v11p_sd15_canny", "control_v11p_sd15_canny",
"control_v11p_sd15_openpose", "control_v11p_sd15_openpose",
@ -133,6 +134,7 @@ class Config(BaseModel):
model: str = DEFAULT_MODEL model: str = DEFAULT_MODEL
sd_local_model_path: str = None sd_local_model_path: str = None
sd_controlnet: bool = False sd_controlnet: bool = False
sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
device: str = DEFAULT_DEVICE device: str = DEFAULT_DEVICE
gui: bool = False gui: bool = False
no_gui_auto_close: bool = False no_gui_auto_close: bool = False

View File

@ -6,7 +6,12 @@ import torch
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device from lama_cleaner.helper import (
boxes_from_mask,
resize_max_size,
pad_img_to_modulo,
switch_mps_device,
)
from lama_cleaner.schema import Config, HDStrategy from lama_cleaner.schema import Config, HDStrategy
@ -199,7 +204,9 @@ class InpaintModel:
# only calculate histograms for non-masked parts # only calculate histograms for non-masked parts
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256]) source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256]) reference_histogram, _ = np.histogram(
reference_channel[mask == 0], 256, [0, 256]
)
source_cdf = self._calculate_cdf(source_histogram) source_cdf = self._calculate_cdf(source_histogram)
reference_cdf = self._calculate_cdf(reference_histogram) reference_cdf = self._calculate_cdf(reference_histogram)
@ -273,9 +280,10 @@ class DiffusionInpaintModel(InpaintModel):
origin_size = image.shape[:2] origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length) downsize_image = resize_max_size(image, size_limit=longer_side_length)
downsize_mask = resize_max_size(mask, size_limit=longer_side_length) downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
logger.info( if config.sd_scale != 1:
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}" logger.info(
) f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
)
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config) inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result # only paste masked area result
inpaint_result = cv2.resize( inpaint_result = cv2.resize(
@ -284,5 +292,7 @@ class DiffusionInpaintModel(InpaintModel):
interpolation=cv2.INTER_CUBIC, interpolation=cv2.INTER_CUBIC,
) )
original_pixel_indices = mask < 127 original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices] inpaint_result[original_pixel_indices] = image[:, :, ::-1][
original_pixel_indices
]
return inpaint_result return inpaint_result

View File

@ -203,6 +203,7 @@ class ControlNet(DiffusionInpaintModel):
negative_prompt=config.negative_prompt, negative_prompt=config.negative_prompt,
generator=torch.manual_seed(config.sd_seed), generator=torch.manual_seed(config.sd_seed),
output_type="np.array", output_type="np.array",
callback=self.callback
).images[0] ).images[0]
else: else:
if "canny" in self.sd_controlnet_method: if "canny" in self.sd_controlnet_method:

View File

@ -455,7 +455,7 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.controlnet.in_channels num_channels_latents = self.controlnet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
num_channels_latents, num_channels_latents,

View File

@ -1,6 +1,8 @@
import torch import torch
import gc import gc
from loguru import logger
from lama_cleaner.const import SD15_MODELS from lama_cleaner.const import SD15_MODELS
from lama_cleaner.helper import switch_mps_device from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.controlnet import ControlNet from lama_cleaner.model.controlnet import ControlNet
@ -58,6 +60,7 @@ class ModelManager:
raise NotImplementedError(f"Not supported model: {name}") raise NotImplementedError(f"Not supported model: {name}")
def __call__(self, image, mask, config: Config): def __call__(self, image, mask, config: Config):
self.switch_controlnet_method(control_method=config.controlnet_method)
return self.model(image, mask, config) return self.model(image, mask, config)
def switch(self, new_name: str, **kwargs): def switch(self, new_name: str, **kwargs):
@ -86,7 +89,9 @@ class ModelManager:
del self.model del self.model
torch_gc() torch_gc()
old_method = self.kwargs["sd_controlnet_method"]
self.kwargs["sd_controlnet_method"] = control_method self.kwargs["sd_controlnet_method"] = control_method
self.model = self.init_model( self.model = self.init_model(
self.name, switch_mps_device(self.name, self.device), **self.kwargs self.name, switch_mps_device(self.name, self.device), **self.kwargs
) )
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")

View File

@ -40,7 +40,7 @@ def parse_args():
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
parser.add_argument( parser.add_argument(
"--sd-controlnet-method", "--sd-controlnet-method",
default="control_v11p_sd15_inpaint", default=DEFAULT_CONTROLNET_METHOD,
choices=SD_CONTROLNET_CHOICES, choices=SD_CONTROLNET_CHOICES,
) )
parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP) parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)

View File

@ -4,3 +4,4 @@ from .realesrgan import RealESRGANUpscaler
from .gfpgan_plugin import GFPGANPlugin from .gfpgan_plugin import GFPGANPlugin
from .restoreformer import RestoreFormerPlugin from .restoreformer import RestoreFormerPlugin
from .gif import MakeGIF from .gif import MakeGIF
from .anime_seg import AnimeSeg

View File

@ -96,4 +96,5 @@ class Config(BaseModel):
p2p_guidance_scale: float = 7.5 p2p_guidance_scale: float = 7.5
# ControlNet # ControlNet
controlnet_conditioning_scale: float = 0.4 controlnet_conditioning_scale: float = 1.0
controlnet_method: str = "control_v11p_sd15_canny"

View File

@ -1,9 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio
import hashlib
import os import os
import hashlib
from lama_cleaner.plugins.anime_seg import AnimeSeg
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -32,6 +29,7 @@ from lama_cleaner.plugins import (
MakeGIF, MakeGIF,
GFPGANPlugin, GFPGANPlugin,
RestoreFormerPlugin, RestoreFormerPlugin,
AnimeSeg,
) )
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
@ -84,7 +82,15 @@ BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
class NoFlaskwebgui(logging.Filter): class NoFlaskwebgui(logging.Filter):
def filter(self, record): def filter(self, record):
return "flaskwebgui-keep-server-alive" not in record.getMessage() msg = record.getMessage()
if "Running on http:" in msg:
print(msg[msg.index("Running on http:") :])
return (
"flaskwebgui-keep-server-alive" not in msg
and "socket.io" not in msg
and "This is a development server." not in msg
)
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
@ -92,6 +98,9 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition"]) CORS(app, expose_headers=["Content-Disposition"])
sio_logger = logging.getLogger("sio-logger")
sio_logger.setLevel(logging.ERROR)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading") socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
model: ModelManager = None model: ModelManager = None
@ -254,6 +263,7 @@ def process():
p2p_image_guidance_scale=form["p2pImageGuidanceScale"], p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
p2p_guidance_scale=form["p2pGuidanceScale"], p2p_guidance_scale=form["p2pGuidanceScale"],
controlnet_conditioning_scale=form["controlnet_conditioning_scale"], controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
controlnet_method=form["controlnet_method"],
) )
if config.sd_seed == -1: if config.sd_seed == -1:
@ -263,7 +273,6 @@ def process():
logger.info(f"Origin image shape: {original_shape}") logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
logger.info(f"Resized image shape: {image.shape}")
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
@ -436,17 +445,6 @@ def switch_model():
return f"ok, switch to {new_name}", 200 return f"ok, switch to {new_name}", 200
@app.route("/controlnet_method", methods=["POST"])
def switch_controlnet_method():
new_method = request.form.get("method")
try:
model.switch_controlnet_method(new_method)
except NotImplementedError:
return f"Failed switch to {new_method} not implemented", 500
return f"Switch to {new_method}", 200
@app.route("/") @app.route("/")
def index(): def index():
return send_file(os.path.join(BUILD_DIR, "index.html")) return send_file(os.path.join(BUILD_DIR, "index.html"))
@ -603,4 +601,10 @@ def main(args):
) )
ui.run() ui.run()
else: else:
socketio.run(app, host=args.host, port=args.port, debug=args.debug) socketio.run(
app,
host=args.host,
port=args.port,
debug=args.debug,
allow_unsafe_werkzeug=True,
)

View File

@ -125,7 +125,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
prompt="a fox sitting on a bench", prompt="a fox sitting on a bench",
sd_steps=sd_steps, sd_steps=sd_steps,
controlnet_conditioning_scale=1.0, controlnet_conditioning_scale=1.0,
sd_strength=1.0 sd_strength=1.0,
) )
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
@ -138,3 +138,42 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
) )
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
def test_controlnet_switch(sd_device, sampler):
if sd_device == "cuda" and not torch.cuda.is_available():
return
if device == "mps" and not torch.backends.mps.is_available():
return
sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
sd_run_local=False,
disable_nsfw=True,
sd_cpu_textencoder=False,
cpu_offload=True,
sd_controlnet_method="control_v11p_sd15_canny",
)
cfg = get_config(
HDStrategy.ORIGINAL,
prompt="a fox sitting on a bench",
sd_steps=sd_steps,
controlnet_method="control_v11p_sd15_inpaint",
)
cfg.sd_sampler = sampler
name = f"device_{sd_device}_{sampler}"
assert_equal(
model,
cfg,
f"sd_controlnet_switch_to_inpaint_local_model_{name}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -35,6 +35,7 @@ def _save(img, name):
def test_remove_bg(): def test_remove_bg():
model = RemoveBG() model = RemoveBG()
res = model.forward(bgr_img) res = model.forward(bgr_img)
res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
_save(res, "test_remove_bg.png") _save(res, "test_remove_bg.png")

View File

@ -16,6 +16,7 @@ def save_config(
model, model,
sd_local_model_path, sd_local_model_path,
sd_controlnet, sd_controlnet,
sd_controlnet_method,
device, device,
gui, gui,
no_gui_auto_close, no_gui_auto_close,
@ -182,6 +183,11 @@ def main(config_file: str):
sd_controlnet = gr.Checkbox( sd_controlnet = gr.Checkbox(
init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}" init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
) )
sd_controlnet_method = gr.Radio(
SD_CONTROLNET_CHOICES,
lable="ControlNet method",
value=init_config.sd_controlnet_method,
)
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
cpu_offload = gr.Checkbox( cpu_offload = gr.Checkbox(
init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}" init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
@ -207,6 +213,7 @@ def main(config_file: str):
model, model,
sd_local_model_path, sd_local_model_path,
sd_controlnet, sd_controlnet,
sd_controlnet_method,
device, device,
gui, gui,
no_gui_auto_close, no_gui_auto_close,

View File

@ -2,6 +2,7 @@ torch>=1.9.0
opencv-python opencv-python
flask==2.2.3 flask==2.2.3
flask-socketio flask-socketio
simple-websocket
flask_cors flask_cors
flaskwebgui==0.3.5 flaskwebgui==0.3.5
pydantic pydantic