From ec2db92ad950f2873e71edebb78c901777f554da Mon Sep 17 00:00:00 2001 From: Qing Date: Sat, 10 Feb 2024 12:34:56 +0800 Subject: [PATCH] add switch interactiveSegModel --- iopaint/api.py | 43 +++------------ iopaint/plugins/interactive_seg.py | 17 +++++- iopaint/schema.py | 2 + web_app/src/components/Settings.tsx | 81 +++++++++++++++++++++++++++-- web_app/src/lib/types.ts | 2 + 5 files changed, 103 insertions(+), 42 deletions(-) diff --git a/iopaint/api.py b/iopaint/api.py index cace639..2633741 100644 --- a/iopaint/api.py +++ b/iopaint/api.py @@ -43,7 +43,7 @@ from iopaint.helper import ( ) from iopaint.model.utils import torch_gc from iopaint.model_manager import ModelManager -from iopaint.plugins import build_plugins, RealESRGANUpscaler +from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg from iopaint.plugins.base_plugin import BasePlugin from iopaint.plugins.remove_bg import RemoveBG from iopaint.schema import ( @@ -59,6 +59,7 @@ from iopaint.schema import ( RemoveBGModel, SwitchPluginModelRequest, ModelInfo, + InteractiveSegModel, RealESRGANModel, ) @@ -202,6 +203,9 @@ class Api: self.config.remove_bg_model = req.model_name if req.plugin_name == RealESRGANUpscaler.name: self.config.realesrgan_model = req.model_name + if req.plugin_name == InteractiveSeg.name: + self.config.interactive_seg_model = req.model_name + torch_gc() def api_server_config(self) -> ServerConfigResponse: plugins = [] @@ -221,6 +225,8 @@ class Api: removeBGModels=RemoveBGModel.values(), realesrganModel=self.config.realesrgan_model, realesrganModels=RealESRGANModel.values(), + interactiveSegModel=self.config.interactive_seg_model, + interactiveSegModels=InteractiveSegModel.values(), enableFileManager=self.file_manager is not None, enableAutoSaving=self.config.output_dir is not None, enableControlnet=self.model_manager.enable_controlnet, @@ -388,38 +394,3 @@ class Api: cpu_offload=self.config.cpu_offload, callback=diffuser_callback, ) - - -if __name__ == "__main__": - from iopaint.schema import InteractiveSegModel, RealESRGANModel - - app = FastAPI() - api = Api( - app, - ApiConfig( - host="127.0.0.1", - port=8080, - model="lama", - no_half=False, - cpu_offload=False, - disable_nsfw_checker=False, - cpu_textencoder=False, - device="cpu", - input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images", - output_dir="/Users/cwq/code/github/lama-cleaner/tmp", - quality=100, - enable_interactive_seg=False, - interactive_seg_model=InteractiveSegModel.vit_b, - interactive_seg_device="cpu", - enable_remove_bg=False, - enable_anime_seg=False, - enable_realesrgan=False, - realesrgan_device="cpu", - realesrgan_model=RealESRGANModel.realesr_general_x4v3, - enable_gfpgan=False, - gfpgan_device="cpu", - enable_restoreformer=False, - restoreformer_device="cpu", - ), - ) - api.launch() diff --git a/iopaint/plugins/interactive_seg.py b/iopaint/plugins/interactive_seg.py index 16b819d..f19a3a8 100644 --- a/iopaint/plugins/interactive_seg.py +++ b/iopaint/plugins/interactive_seg.py @@ -37,16 +37,31 @@ class InteractiveSeg(BasePlugin): def __init__(self, model_name, device): super().__init__() + self.model_name = model_name + self.device = device + self._init_session(model_name) + + def _init_session(self, model_name: str): model_path = download_model( SEGMENT_ANYTHING_MODELS[model_name]["url"], SEGMENT_ANYTHING_MODELS[model_name]["md5"], ) logger.info(f"SegmentAnything model path: {model_path}") self.predictor = SamPredictor( - sam_model_registry[model_name](checkpoint=model_path).to(device) + sam_model_registry[model_name](checkpoint=model_path).to(self.device) ) self.prev_img_md5 = None + def switch_model(self, new_model_name): + if self.model_name == new_model_name: + return + + logger.info( + f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}" + ) + self._init_session(new_model_name) + self.model_name = new_model_name + def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest() return self.forward(rgb_np_img, req.clicks, img_md5) diff --git a/iopaint/schema.py b/iopaint/schema.py index b7c8b44..3e31d77 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -427,6 +427,8 @@ class ServerConfigResponse(BaseModel): removeBGModels: List[RemoveBGModel] realesrganModel: RealESRGANModel realesrganModels: List[RealESRGANModel] + interactiveSegModel: InteractiveSegModel + interactiveSegModels: List[InteractiveSegModel] enableFileManager: bool enableAutoSaving: bool enableControlnet: bool diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index 18a021c..79f08bd 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -58,6 +58,7 @@ const formSchema = z.object({ enableAutoExtractPrompt: z.boolean(), removeBGModel: z.string(), realesrganModel: z.string(), + interactiveSegModel: z.string(), }) const TAB_GENERAL = "General" @@ -110,6 +111,7 @@ export function SettingsDialog() { outputDirectory: fileManagerState.outputDirectory, removeBGModel: serverConfig?.removeBGModel, realesrganModel: serverConfig?.realesrganModel, + interactiveSegModel: serverConfig?.interactiveSegModel, }, }) @@ -118,6 +120,7 @@ export function SettingsDialog() { setServerConfig(serverConfig) form.setValue("removeBGModel", serverConfig.removeBGModel) form.setValue("realesrganModel", serverConfig.realesrganModel) + form.setValue("interactiveSegModel", serverConfig.interactiveSegModel) } }, [form, serverConfig]) @@ -138,13 +141,19 @@ export function SettingsDialog() { const shouldSwitchModel = model.name !== settings.model.name const shouldSwitchRemoveBGModel = - serverConfig?.removeBGModel !== values.removeBGModel + serverConfig?.removeBGModel !== values.removeBGModel && removeBGEnabled const shouldSwitchRealesrganModel = - serverConfig?.realesrganModel !== values.realesrganModel + serverConfig?.realesrganModel !== values.realesrganModel && + realesrganEnabled + const shouldSwitchInteractiveModel = + serverConfig?.interactiveSegModel !== values.interactiveSegModel && + interactiveSegEnabled + const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel || - shouldSwitchRealesrganModel + shouldSwitchRealesrganModel || + shouldSwitchInteractiveModel if (showModelSwitching) { const newModelSwitchingTexts: string[] = [] @@ -163,6 +172,11 @@ export function SettingsDialog() { `Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}` ) } + if (shouldSwitchInteractiveModel) { + newModelSwitchingTexts.push( + `Switching ${PluginName.InteractiveSeg} model from ${serverConfig?.interactiveSegModel} to ${values.interactiveSegModel}` + ) + } setModelSwitchingTexts(newModelSwitchingTexts) updateAppState({ disableShortCuts: true }) @@ -195,7 +209,7 @@ export function SettingsDialog() { } catch (error: any) { toast({ variant: "destructive", - title: `Switch RemoveBG model to ${model.name} failed: ${error}`, + title: `Switch RemoveBG model to ${values.removeBGModel} failed: ${error}`, }) } } @@ -212,7 +226,24 @@ export function SettingsDialog() { } catch (error: any) { toast({ variant: "destructive", - title: `Switch RealESRGAN model to ${model.name} failed: ${error}`, + title: `Switch RealESRGAN model to ${values.realesrganModel} failed: ${error}`, + }) + } + } + + if (shouldSwitchInteractiveModel) { + try { + const res = await switchPluginModel( + PluginName.InteractiveSeg, + values.interactiveSegModel + ) + if (res.status !== 200) { + throw new Error(res.statusText) + } + } catch (error: any) { + toast({ + variant: "destructive", + title: `Switch ${PluginName.InteractiveSeg} model to ${values.interactiveSegModel} failed: ${error}`, }) } } @@ -245,6 +276,9 @@ export function SettingsDialog() { const realesrganEnabled = plugins.some( (plugin) => plugin.name === PluginName.RealESRGAN ) + const interactiveSegEnabled = plugins.some( + (plugin) => plugin.name === PluginName.InteractiveSeg + ) function onOpenChange(value: boolean) { toggleOpen() @@ -522,6 +556,43 @@ export function SettingsDialog() { )} /> + + + + ( + +
+ Interactive Segmentation + + Interactive Segmentation Model + +
+ +
+ )} + /> ) } diff --git a/web_app/src/lib/types.ts b/web_app/src/lib/types.ts index 9452c66..38cacc0 100644 --- a/web_app/src/lib/types.ts +++ b/web_app/src/lib/types.ts @@ -19,6 +19,8 @@ export interface ServerConfig { removeBGModels: string[] realesrganModel: string realesrganModels: string[] + interactiveSegModel: string + interactiveSegModels: string[] enableFileManager: boolean enableAutoSaving: boolean enableControlnet: boolean