add switch interactiveSegModel

This commit is contained in:
Qing 2024-02-10 12:34:56 +08:00
parent 9aa5a7e0ba
commit ec2db92ad9
5 changed files with 103 additions and 42 deletions

View File

@ -43,7 +43,7 @@ from iopaint.helper import (
) )
from iopaint.model.utils import torch_gc from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins, RealESRGANUpscaler from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
from iopaint.plugins.base_plugin import BasePlugin from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import ( from iopaint.schema import (
@ -59,6 +59,7 @@ from iopaint.schema import (
RemoveBGModel, RemoveBGModel,
SwitchPluginModelRequest, SwitchPluginModelRequest,
ModelInfo, ModelInfo,
InteractiveSegModel,
RealESRGANModel, RealESRGANModel,
) )
@ -202,6 +203,9 @@ class Api:
self.config.remove_bg_model = req.model_name self.config.remove_bg_model = req.model_name
if req.plugin_name == RealESRGANUpscaler.name: if req.plugin_name == RealESRGANUpscaler.name:
self.config.realesrgan_model = req.model_name self.config.realesrgan_model = req.model_name
if req.plugin_name == InteractiveSeg.name:
self.config.interactive_seg_model = req.model_name
torch_gc()
def api_server_config(self) -> ServerConfigResponse: def api_server_config(self) -> ServerConfigResponse:
plugins = [] plugins = []
@ -221,6 +225,8 @@ class Api:
removeBGModels=RemoveBGModel.values(), removeBGModels=RemoveBGModel.values(),
realesrganModel=self.config.realesrgan_model, realesrganModel=self.config.realesrgan_model,
realesrganModels=RealESRGANModel.values(), realesrganModels=RealESRGANModel.values(),
interactiveSegModel=self.config.interactive_seg_model,
interactiveSegModels=InteractiveSegModel.values(),
enableFileManager=self.file_manager is not None, enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None, enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet, enableControlnet=self.model_manager.enable_controlnet,
@ -388,38 +394,3 @@ class Api:
cpu_offload=self.config.cpu_offload, cpu_offload=self.config.cpu_offload,
callback=diffuser_callback, callback=diffuser_callback,
) )
if __name__ == "__main__":
from iopaint.schema import InteractiveSegModel, RealESRGANModel
app = FastAPI()
api = Api(
app,
ApiConfig(
host="127.0.0.1",
port=8080,
model="lama",
no_half=False,
cpu_offload=False,
disable_nsfw_checker=False,
cpu_textencoder=False,
device="cpu",
input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images",
output_dir="/Users/cwq/code/github/lama-cleaner/tmp",
quality=100,
enable_interactive_seg=False,
interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device="cpu",
enable_remove_bg=False,
enable_anime_seg=False,
enable_realesrgan=False,
realesrgan_device="cpu",
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
enable_gfpgan=False,
gfpgan_device="cpu",
enable_restoreformer=False,
restoreformer_device="cpu",
),
)
api.launch()

View File

@ -37,16 +37,31 @@ class InteractiveSeg(BasePlugin):
def __init__(self, model_name, device): def __init__(self, model_name, device):
super().__init__() super().__init__()
self.model_name = model_name
self.device = device
self._init_session(model_name)
def _init_session(self, model_name: str):
model_path = download_model( model_path = download_model(
SEGMENT_ANYTHING_MODELS[model_name]["url"], SEGMENT_ANYTHING_MODELS[model_name]["url"],
SEGMENT_ANYTHING_MODELS[model_name]["md5"], SEGMENT_ANYTHING_MODELS[model_name]["md5"],
) )
logger.info(f"SegmentAnything model path: {model_path}") logger.info(f"SegmentAnything model path: {model_path}")
self.predictor = SamPredictor( self.predictor = SamPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(device) sam_model_registry[model_name](checkpoint=model_path).to(self.device)
) )
self.prev_img_md5 = None self.prev_img_md5 = None
def switch_model(self, new_model_name):
if self.model_name == new_model_name:
return
logger.info(
f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
)
self._init_session(new_model_name)
self.model_name = new_model_name
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest() img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
return self.forward(rgb_np_img, req.clicks, img_md5) return self.forward(rgb_np_img, req.clicks, img_md5)

View File

@ -427,6 +427,8 @@ class ServerConfigResponse(BaseModel):
removeBGModels: List[RemoveBGModel] removeBGModels: List[RemoveBGModel]
realesrganModel: RealESRGANModel realesrganModel: RealESRGANModel
realesrganModels: List[RealESRGANModel] realesrganModels: List[RealESRGANModel]
interactiveSegModel: InteractiveSegModel
interactiveSegModels: List[InteractiveSegModel]
enableFileManager: bool enableFileManager: bool
enableAutoSaving: bool enableAutoSaving: bool
enableControlnet: bool enableControlnet: bool

View File

@ -58,6 +58,7 @@ const formSchema = z.object({
enableAutoExtractPrompt: z.boolean(), enableAutoExtractPrompt: z.boolean(),
removeBGModel: z.string(), removeBGModel: z.string(),
realesrganModel: z.string(), realesrganModel: z.string(),
interactiveSegModel: z.string(),
}) })
const TAB_GENERAL = "General" const TAB_GENERAL = "General"
@ -110,6 +111,7 @@ export function SettingsDialog() {
outputDirectory: fileManagerState.outputDirectory, outputDirectory: fileManagerState.outputDirectory,
removeBGModel: serverConfig?.removeBGModel, removeBGModel: serverConfig?.removeBGModel,
realesrganModel: serverConfig?.realesrganModel, realesrganModel: serverConfig?.realesrganModel,
interactiveSegModel: serverConfig?.interactiveSegModel,
}, },
}) })
@ -118,6 +120,7 @@ export function SettingsDialog() {
setServerConfig(serverConfig) setServerConfig(serverConfig)
form.setValue("removeBGModel", serverConfig.removeBGModel) form.setValue("removeBGModel", serverConfig.removeBGModel)
form.setValue("realesrganModel", serverConfig.realesrganModel) form.setValue("realesrganModel", serverConfig.realesrganModel)
form.setValue("interactiveSegModel", serverConfig.interactiveSegModel)
} }
}, [form, serverConfig]) }, [form, serverConfig])
@ -138,13 +141,19 @@ export function SettingsDialog() {
const shouldSwitchModel = model.name !== settings.model.name const shouldSwitchModel = model.name !== settings.model.name
const shouldSwitchRemoveBGModel = const shouldSwitchRemoveBGModel =
serverConfig?.removeBGModel !== values.removeBGModel serverConfig?.removeBGModel !== values.removeBGModel && removeBGEnabled
const shouldSwitchRealesrganModel = const shouldSwitchRealesrganModel =
serverConfig?.realesrganModel !== values.realesrganModel serverConfig?.realesrganModel !== values.realesrganModel &&
realesrganEnabled
const shouldSwitchInteractiveModel =
serverConfig?.interactiveSegModel !== values.interactiveSegModel &&
interactiveSegEnabled
const showModelSwitching = const showModelSwitching =
shouldSwitchModel || shouldSwitchModel ||
shouldSwitchRemoveBGModel || shouldSwitchRemoveBGModel ||
shouldSwitchRealesrganModel shouldSwitchRealesrganModel ||
shouldSwitchInteractiveModel
if (showModelSwitching) { if (showModelSwitching) {
const newModelSwitchingTexts: string[] = [] const newModelSwitchingTexts: string[] = []
@ -163,6 +172,11 @@ export function SettingsDialog() {
`Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}` `Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}`
) )
} }
if (shouldSwitchInteractiveModel) {
newModelSwitchingTexts.push(
`Switching ${PluginName.InteractiveSeg} model from ${serverConfig?.interactiveSegModel} to ${values.interactiveSegModel}`
)
}
setModelSwitchingTexts(newModelSwitchingTexts) setModelSwitchingTexts(newModelSwitchingTexts)
updateAppState({ disableShortCuts: true }) updateAppState({ disableShortCuts: true })
@ -195,7 +209,7 @@ export function SettingsDialog() {
} catch (error: any) { } catch (error: any) {
toast({ toast({
variant: "destructive", variant: "destructive",
title: `Switch RemoveBG model to ${model.name} failed: ${error}`, title: `Switch RemoveBG model to ${values.removeBGModel} failed: ${error}`,
}) })
} }
} }
@ -212,7 +226,24 @@ export function SettingsDialog() {
} catch (error: any) { } catch (error: any) {
toast({ toast({
variant: "destructive", variant: "destructive",
title: `Switch RealESRGAN model to ${model.name} failed: ${error}`, title: `Switch RealESRGAN model to ${values.realesrganModel} failed: ${error}`,
})
}
}
if (shouldSwitchInteractiveModel) {
try {
const res = await switchPluginModel(
PluginName.InteractiveSeg,
values.interactiveSegModel
)
if (res.status !== 200) {
throw new Error(res.statusText)
}
} catch (error: any) {
toast({
variant: "destructive",
title: `Switch ${PluginName.InteractiveSeg} model to ${values.interactiveSegModel} failed: ${error}`,
}) })
} }
} }
@ -245,6 +276,9 @@ export function SettingsDialog() {
const realesrganEnabled = plugins.some( const realesrganEnabled = plugins.some(
(plugin) => plugin.name === PluginName.RealESRGAN (plugin) => plugin.name === PluginName.RealESRGAN
) )
const interactiveSegEnabled = plugins.some(
(plugin) => plugin.name === PluginName.InteractiveSeg
)
function onOpenChange(value: boolean) { function onOpenChange(value: boolean) {
toggleOpen() toggleOpen()
@ -522,6 +556,43 @@ export function SettingsDialog() {
</FormItem> </FormItem>
)} )}
/> />
<Separator />
<FormField
control={form.control}
name="interactiveSegModel"
render={({ field }) => (
<FormItem className="flex items-center justify-between">
<div className="space-y-0.5">
<FormLabel>Interactive Segmentation</FormLabel>
<FormDescription>
Interactive Segmentation Model
</FormDescription>
</div>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!interactiveSegEnabled}
>
<FormControl>
<SelectTrigger className="w-auto">
<SelectValue placeholder="Select interactive segmentation model" />
</SelectTrigger>
</FormControl>
<SelectContent align="end">
<SelectGroup>
{serverConfig?.interactiveSegModels.map((model) => (
<SelectItem key={model} value={model}>
{model}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormItem>
)}
/>
</div> </div>
) )
} }

View File

@ -19,6 +19,8 @@ export interface ServerConfig {
removeBGModels: string[] removeBGModels: string[]
realesrganModel: string realesrganModel: string
realesrganModels: string[] realesrganModels: string[]
interactiveSegModel: string
interactiveSegModels: string[]
enableFileManager: boolean enableFileManager: boolean
enableAutoSaving: boolean enableAutoSaving: boolean
enableControlnet: boolean enableControlnet: boolean