From 3eef8f4dae1f2d0ac43438f40d12469f2d2d5cc5 Mon Sep 17 00:00:00 2001 From: Qing Date: Sat, 13 May 2023 13:45:27 +0800 Subject: [PATCH] switch controlnet in webui --- lama_cleaner/app/src/adapters/inpainting.ts | 6 +- .../src/components/SidePanel/SidePanel.tsx | 57 +++++++++++++------ lama_cleaner/app/src/store/Atoms.tsx | 16 +++++- lama_cleaner/const.py | 2 + lama_cleaner/model/base.py | 22 +++++-- lama_cleaner/model/controlnet.py | 1 + ...ine_stable_diffusion_controlnet_inpaint.py | 2 +- lama_cleaner/model_manager.py | 5 ++ lama_cleaner/parse_args.py | 2 +- lama_cleaner/plugins/__init__.py | 1 + lama_cleaner/schema.py | 3 +- lama_cleaner/server.py | 40 +++++++------ lama_cleaner/tests/test_controlnet.py | 41 ++++++++++++- lama_cleaner/tests/test_plugins.py | 1 + lama_cleaner/web_config.py | 7 +++ requirements.txt | 1 + 16 files changed, 161 insertions(+), 46 deletions(-) diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index bc64d98..3e78f93 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -1,5 +1,5 @@ import { PluginName } from '../components/Plugins/Plugins' -import { Rect, Settings } from '../store/Atoms' +import { ControlNetMethodMap, Rect, Settings } from '../store/Atoms' import { dataURItoBlob, loadImage, srcToFile } from '../utils' export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}` @@ -92,6 +92,10 @@ export default async function inpaint( 'controlnet_conditioning_scale', settings.controlnetConditioningScale.toString() ) + fd.append( + 'controlnet_method', + ControlNetMethodMap[settings.controlnetMethod.toString()] + ) try { const res = await fetch(`${API_ENDPOINT}/inpaint`, { diff --git a/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx b/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx index 97d8049..668260f 100644 --- a/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx +++ b/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx @@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil' import * as PopoverPrimitive from '@radix-ui/react-popover' import { useToggle } from 'react-use' import { + ControlNetMethod, isControlNetState, isInpaintingState, negativePropmtState, @@ -47,6 +48,44 @@ const SidePanel = () => { } } + const renderConterNetSetting = () => { + return ( + <> + { + const method = val as ControlNetMethod + setSettingState(old => { + return { ...old, controlnetMethod: method } + }) + }} + /> + } + /> + + { + const val = value.length === 0 ? 0 : parseFloat(value) + setSettingState(old => { + return { ...old, controlnetConditioningScale: val } + }) + }} + /> + + ) + } + return (
@@ -58,6 +97,8 @@ const SidePanel = () => { + {isControlNet && renderConterNetSetting()} + { }} /> - {isControlNet && ( - { - const val = value.length === 0 ? 0 : parseFloat(value) - setSettingState(old => { - return { ...old, controlnetConditioningScale: val } - }) - }} - /> - )} - {downsize_image.shape}" - ) + if config.sd_scale != 1: + logger.info( + f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}" + ) inpaint_result = self._pad_forward(downsize_image, downsize_mask, config) # only paste masked area result inpaint_result = cv2.resize( @@ -284,5 +292,7 @@ class DiffusionInpaintModel(InpaintModel): interpolation=cv2.INTER_CUBIC, ) original_pixel_indices = mask < 127 - inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices] + inpaint_result[original_pixel_indices] = image[:, :, ::-1][ + original_pixel_indices + ] return inpaint_result diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py index cb0a963..fa2060c 100644 --- a/lama_cleaner/model/controlnet.py +++ b/lama_cleaner/model/controlnet.py @@ -203,6 +203,7 @@ class ControlNet(DiffusionInpaintModel): negative_prompt=config.negative_prompt, generator=torch.manual_seed(config.sd_seed), output_type="np.array", + callback=self.callback ).images[0] else: if "canny" in self.sd_controlnet_method: diff --git a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py index b2a181c..f5883b8 100644 --- a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py +++ b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py @@ -455,7 +455,7 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline timesteps = self.scheduler.timesteps # 6. Prepare latent variables - num_channels_latents = self.controlnet.in_channels + num_channels_latents = self.controlnet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 74330bb..20c61a7 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -1,6 +1,8 @@ import torch import gc +from loguru import logger + from lama_cleaner.const import SD15_MODELS from lama_cleaner.helper import switch_mps_device from lama_cleaner.model.controlnet import ControlNet @@ -58,6 +60,7 @@ class ModelManager: raise NotImplementedError(f"Not supported model: {name}") def __call__(self, image, mask, config: Config): + self.switch_controlnet_method(control_method=config.controlnet_method) return self.model(image, mask, config) def switch(self, new_name: str, **kwargs): @@ -86,7 +89,9 @@ class ModelManager: del self.model torch_gc() + old_method = self.kwargs["sd_controlnet_method"] self.kwargs["sd_controlnet_method"] = control_method self.model = self.init_model( self.name, switch_mps_device(self.name, self.device), **self.kwargs ) + logger.info(f"Switch ControlNet method from {old_method} to {control_method}") diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 1a3f48c..cf72d5a 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -40,7 +40,7 @@ def parse_args(): parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) parser.add_argument( "--sd-controlnet-method", - default="control_v11p_sd15_inpaint", + default=DEFAULT_CONTROLNET_METHOD, choices=SD_CONTROLNET_CHOICES, ) parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP) diff --git a/lama_cleaner/plugins/__init__.py b/lama_cleaner/plugins/__init__.py index 3c45e16..5a75ea3 100644 --- a/lama_cleaner/plugins/__init__.py +++ b/lama_cleaner/plugins/__init__.py @@ -4,3 +4,4 @@ from .realesrgan import RealESRGANUpscaler from .gfpgan_plugin import GFPGANPlugin from .restoreformer import RestoreFormerPlugin from .gif import MakeGIF +from .anime_seg import AnimeSeg diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 2674316..0be9f47 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -96,4 +96,5 @@ class Config(BaseModel): p2p_guidance_scale: float = 7.5 # ControlNet - controlnet_conditioning_scale: float = 0.4 + controlnet_conditioning_scale: float = 1.0 + controlnet_method: str = "control_v11p_sd15_canny" diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 8d1984c..82e6fc8 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -1,9 +1,6 @@ #!/usr/bin/env python3 -import asyncio -import hashlib import os - -from lama_cleaner.plugins.anime_seg import AnimeSeg +import hashlib os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -32,6 +29,7 @@ from lama_cleaner.plugins import ( MakeGIF, GFPGANPlugin, RestoreFormerPlugin, + AnimeSeg, ) from lama_cleaner.schema import Config @@ -84,7 +82,15 @@ BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") class NoFlaskwebgui(logging.Filter): def filter(self, record): - return "flaskwebgui-keep-server-alive" not in record.getMessage() + msg = record.getMessage() + if "Running on http:" in msg: + print(msg[msg.index("Running on http:") :]) + + return ( + "flaskwebgui-keep-server-alive" not in msg + and "socket.io" not in msg + and "This is a development server." not in msg + ) logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) @@ -92,6 +98,9 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app.config["JSON_AS_ASCII"] = False CORS(app, expose_headers=["Content-Disposition"]) + +sio_logger = logging.getLogger("sio-logger") +sio_logger.setLevel(logging.ERROR) socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading") model: ModelManager = None @@ -254,6 +263,7 @@ def process(): p2p_image_guidance_scale=form["p2pImageGuidanceScale"], p2p_guidance_scale=form["p2pGuidanceScale"], controlnet_conditioning_scale=form["controlnet_conditioning_scale"], + controlnet_method=form["controlnet_method"], ) if config.sd_seed == -1: @@ -263,7 +273,6 @@ def process(): logger.info(f"Origin image shape: {original_shape}") image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) - logger.info(f"Resized image shape: {image.shape}") mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) @@ -436,17 +445,6 @@ def switch_model(): return f"ok, switch to {new_name}", 200 -@app.route("/controlnet_method", methods=["POST"]) -def switch_controlnet_method(): - new_method = request.form.get("method") - - try: - model.switch_controlnet_method(new_method) - except NotImplementedError: - return f"Failed switch to {new_method} not implemented", 500 - return f"Switch to {new_method}", 200 - - @app.route("/") def index(): return send_file(os.path.join(BUILD_DIR, "index.html")) @@ -603,4 +601,10 @@ def main(args): ) ui.run() else: - socketio.run(app, host=args.host, port=args.port, debug=args.debug) + socketio.run( + app, + host=args.host, + port=args.port, + debug=args.debug, + allow_unsafe_werkzeug=True, + ) diff --git a/lama_cleaner/tests/test_controlnet.py b/lama_cleaner/tests/test_controlnet.py index d699e1b..224dd73 100644 --- a/lama_cleaner/tests/test_controlnet.py +++ b/lama_cleaner/tests/test_controlnet.py @@ -125,7 +125,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler): prompt="a fox sitting on a bench", sd_steps=sd_steps, controlnet_conditioning_scale=1.0, - sd_strength=1.0 + sd_strength=1.0, ) cfg.sd_sampler = sampler @@ -138,3 +138,42 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler): img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", ) + + +@pytest.mark.parametrize("sd_device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.uni_pc]) +def test_controlnet_switch(sd_device, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + + sd_steps = 1 if sd_device == "cpu" else 30 + model = ModelManager( + name="sd1.5", + sd_controlnet=True, + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=False, + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=True, + sd_controlnet_method="control_v11p_sd15_canny", + ) + cfg = get_config( + HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + controlnet_method="control_v11p_sd15_inpaint", + ) + cfg.sd_sampler = sampler + + name = f"device_{sd_device}_{sampler}" + + assert_equal( + model, + cfg, + f"sd_controlnet_switch_to_inpaint_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py index ba19b40..834c719 100644 --- a/lama_cleaner/tests/test_plugins.py +++ b/lama_cleaner/tests/test_plugins.py @@ -35,6 +35,7 @@ def _save(img, name): def test_remove_bg(): model = RemoveBG() res = model.forward(bgr_img) + res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA) _save(res, "test_remove_bg.png") diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py index 0ae9b31..61efdcc 100644 --- a/lama_cleaner/web_config.py +++ b/lama_cleaner/web_config.py @@ -16,6 +16,7 @@ def save_config( model, sd_local_model_path, sd_controlnet, + sd_controlnet_method, device, gui, no_gui_auto_close, @@ -182,6 +183,11 @@ def main(config_file: str): sd_controlnet = gr.Checkbox( init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}" ) + sd_controlnet_method = gr.Radio( + SD_CONTROLNET_CHOICES, + lable="ControlNet method", + value=init_config.sd_controlnet_method, + ) no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}") cpu_offload = gr.Checkbox( init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}" @@ -207,6 +213,7 @@ def main(config_file: str): model, sd_local_model_path, sd_controlnet, + sd_controlnet_method, device, gui, no_gui_auto_close, diff --git a/requirements.txt b/requirements.txt index d64b969..2fa4d4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ torch>=1.9.0 opencv-python flask==2.2.3 flask-socketio +simple-websocket flask_cors flaskwebgui==0.3.5 pydantic