From f52dbc10914312fb2f5ff5ed6915363abaacdf22 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 8 Feb 2024 17:16:57 +0800 Subject: [PATCH] add realesrGAN selection --- iopaint/api.py | 7 ++- iopaint/plugins/realesrgan.py | 16 +++++- iopaint/schema.py | 4 +- iopaint/web_config.py | 2 +- web_app/src/components/Settings.tsx | 76 +++++++++++++++++++++++++++-- web_app/src/lib/types.ts | 2 + 6 files changed, 98 insertions(+), 9 deletions(-) diff --git a/iopaint/api.py b/iopaint/api.py index ab54cfe..8fa7bfd 100644 --- a/iopaint/api.py +++ b/iopaint/api.py @@ -43,7 +43,7 @@ from iopaint.helper import ( ) from iopaint.model.utils import torch_gc from iopaint.model_manager import ModelManager -from iopaint.plugins import build_plugins +from iopaint.plugins import build_plugins, RealESRGANUpscaler from iopaint.plugins.base_plugin import BasePlugin from iopaint.plugins.remove_bg import RemoveBG from iopaint.schema import ( @@ -59,6 +59,7 @@ from iopaint.schema import ( RemoveBGModel, SwitchPluginModelRequest, ModelInfo, + RealESRGANModel, ) CURRENT_DIR = Path(__file__).parent.absolute().resolve() @@ -192,6 +193,8 @@ class Api: self.plugins[req.plugin_name].switch_model(req.model_name) if req.plugin_name == RemoveBG.name: self.config.remove_bg_model = req.model_name + if req.plugin_name == RealESRGANUpscaler.name: + self.config.realesrgan_model = req.model_name def api_server_config(self) -> ServerConfigResponse: plugins = [] @@ -209,6 +212,8 @@ class Api: modelInfos=self.model_manager.scan_models(), removeBGModel=self.config.remove_bg_model, removeBGModels=RemoveBGModel.values(), + realesrganModel=self.config.realesrgan_model, + realesrganModels=RealESRGANModel.values(), enableFileManager=self.file_manager is not None, enableAutoSaving=self.config.output_dir is not None, enableControlnet=self.model_manager.enable_controlnet, diff --git a/iopaint/plugins/realesrgan.py b/iopaint/plugins/realesrgan.py index 8165fa3..5275700 100644 --- a/iopaint/plugins/realesrgan.py +++ b/iopaint/plugins/realesrgan.py @@ -14,6 +14,12 @@ class RealESRGANUpscaler(BasePlugin): def __init__(self, name, device, no_half=False): super().__init__() + self.model_name = name + self.device = device + self.no_half = no_half + self._init_model(name) + + def _init_model(self, name): from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer from realesrgan.archs.srvgg_arch import SRVGGNetCompact @@ -70,13 +76,19 @@ class RealESRGANUpscaler(BasePlugin): scale=model_info["scale"], model_path=model_path, model=model_info["model"](), - half=True if "cuda" in str(device) and not no_half else False, + half=True if "cuda" in str(self.device) and not self.no_half else False, tile=512, tile_pad=10, pre_pad=10, - device=device, + device=self.device, ) + def switch_model(self, new_model_name: str): + if self.model_name == new_model_name: + return + self._init_model(new_model_name) + self.model_name = new_model_name + def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}") diff --git a/iopaint/schema.py b/iopaint/schema.py index 91fe7c0..a8972ed 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -423,7 +423,9 @@ class ServerConfigResponse(BaseModel): plugins: List[PluginInfo] modelInfos: List[ModelInfo] removeBGModel: RemoveBGModel - removeBGModels: List[str] + removeBGModels: List[RemoveBGModel] + realesrganModel: RealESRGANModel + realesrganModels: List[RealESRGANModel] enableFileManager: bool enableAutoSaving: bool enableControlnet: bool diff --git a/iopaint/web_config.py b/iopaint/web_config.py index 4f71b24..4e34c14 100644 --- a/iopaint/web_config.py +++ b/iopaint/web_config.py @@ -44,7 +44,7 @@ default_configs = dict( interactive_seg_model=InteractiveSegModel.vit_b, interactive_seg_device=Device.cpu, enable_remove_bg=False, - remove_bg_model=RemoveBGModel.u2net, + remove_bg_model=RemoveBGModel.briaai_rmbg_1_4, enable_anime_seg=False, enable_realesrgan=False, realesrgan_device=Device.cpu, diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index 8ba5514..5138e03 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -57,6 +57,7 @@ const formSchema = z.object({ enableUploadMask: z.boolean(), enableAutoExtractPrompt: z.boolean(), removeBGModel: z.string(), + realesrganModel: z.string(), }) const TAB_GENERAL = "General" @@ -108,6 +109,7 @@ export function SettingsDialog() { inputDirectory: fileManagerState.inputDirectory, outputDirectory: fileManagerState.outputDirectory, removeBGModel: serverConfig?.removeBGModel, + realesrganModel: serverConfig?.realesrganModel, }, }) @@ -115,6 +117,7 @@ export function SettingsDialog() { if (serverConfig) { setServerConfig(serverConfig) form.setValue("removeBGModel", serverConfig.removeBGModel) + form.setValue("realesrganModel", serverConfig.realesrganModel) } }, [form, serverConfig]) @@ -136,7 +139,12 @@ export function SettingsDialog() { const shouldSwitchModel = model.name !== settings.model.name const shouldSwitchRemoveBGModel = serverConfig?.removeBGModel !== values.removeBGModel - const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel + const shouldSwitchRealesrganModel = + serverConfig?.realesrganModel !== values.realesrganModel + const showModelSwitching = + shouldSwitchModel || + shouldSwitchRemoveBGModel || + shouldSwitchRealesrganModel if (showModelSwitching) { const newModelSwitchingTexts: string[] = [] @@ -147,7 +155,12 @@ export function SettingsDialog() { } if (shouldSwitchRemoveBGModel) { newModelSwitchingTexts.push( - `Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}` + `Switching RemoveBG model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}` + ) + } + if (shouldSwitchRealesrganModel) { + newModelSwitchingTexts.push( + `Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}` ) } setModelSwitchingTexts(newModelSwitchingTexts) @@ -182,7 +195,24 @@ export function SettingsDialog() { } catch (error: any) { toast({ variant: "destructive", - title: `Switch removebg model to ${model.name} failed: ${error}`, + title: `Switch RemoveBG model to ${model.name} failed: ${error}`, + }) + } + } + + if (shouldSwitchRealesrganModel) { + try { + const res = await switchPluginModel( + PluginName.RealESRGAN, + values.realesrganModel + ) + if (res.status !== 200) { + throw new Error(res.statusText) + } + } catch (error: any) { + toast({ + variant: "destructive", + title: `Switch RealESRGAN model to ${model.name} failed: ${error}`, }) } } @@ -212,6 +242,9 @@ export function SettingsDialog() { const removeBGEnabled = plugins.some( (plugin) => plugin.name === PluginName.RemoveBG ) + const realesrganEnabled = plugins.some( + (plugin) => plugin.name === PluginName.RealESRGAN + ) function onOpenChange(value: boolean) { toggleOpen() @@ -437,7 +470,7 @@ export function SettingsDialog() { disabled={!removeBGEnabled} > - + @@ -454,6 +487,41 @@ export function SettingsDialog() { )} /> + + + + ( + +
+ RealESRGAN + RealESRGAN Model +
+ +
+ )} + /> ) } diff --git a/web_app/src/lib/types.ts b/web_app/src/lib/types.ts index 69799de..9452c66 100644 --- a/web_app/src/lib/types.ts +++ b/web_app/src/lib/types.ts @@ -17,6 +17,8 @@ export interface ServerConfig { modelInfos: ModelInfo[] removeBGModel: string removeBGModels: string[] + realesrganModel: string + realesrganModels: string[] enableFileManager: boolean enableAutoSaving: boolean enableControlnet: boolean