add switch interactiveSegModel
This commit is contained in:
parent
9aa5a7e0ba
commit
ec2db92ad9
@ -43,7 +43,7 @@ from iopaint.helper import (
|
|||||||
)
|
)
|
||||||
from iopaint.model.utils import torch_gc
|
from iopaint.model.utils import torch_gc
|
||||||
from iopaint.model_manager import ModelManager
|
from iopaint.model_manager import ModelManager
|
||||||
from iopaint.plugins import build_plugins, RealESRGANUpscaler
|
from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
|
||||||
from iopaint.plugins.base_plugin import BasePlugin
|
from iopaint.plugins.base_plugin import BasePlugin
|
||||||
from iopaint.plugins.remove_bg import RemoveBG
|
from iopaint.plugins.remove_bg import RemoveBG
|
||||||
from iopaint.schema import (
|
from iopaint.schema import (
|
||||||
@ -59,6 +59,7 @@ from iopaint.schema import (
|
|||||||
RemoveBGModel,
|
RemoveBGModel,
|
||||||
SwitchPluginModelRequest,
|
SwitchPluginModelRequest,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
|
InteractiveSegModel,
|
||||||
RealESRGANModel,
|
RealESRGANModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -202,6 +203,9 @@ class Api:
|
|||||||
self.config.remove_bg_model = req.model_name
|
self.config.remove_bg_model = req.model_name
|
||||||
if req.plugin_name == RealESRGANUpscaler.name:
|
if req.plugin_name == RealESRGANUpscaler.name:
|
||||||
self.config.realesrgan_model = req.model_name
|
self.config.realesrgan_model = req.model_name
|
||||||
|
if req.plugin_name == InteractiveSeg.name:
|
||||||
|
self.config.interactive_seg_model = req.model_name
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
def api_server_config(self) -> ServerConfigResponse:
|
def api_server_config(self) -> ServerConfigResponse:
|
||||||
plugins = []
|
plugins = []
|
||||||
@ -221,6 +225,8 @@ class Api:
|
|||||||
removeBGModels=RemoveBGModel.values(),
|
removeBGModels=RemoveBGModel.values(),
|
||||||
realesrganModel=self.config.realesrgan_model,
|
realesrganModel=self.config.realesrgan_model,
|
||||||
realesrganModels=RealESRGANModel.values(),
|
realesrganModels=RealESRGANModel.values(),
|
||||||
|
interactiveSegModel=self.config.interactive_seg_model,
|
||||||
|
interactiveSegModels=InteractiveSegModel.values(),
|
||||||
enableFileManager=self.file_manager is not None,
|
enableFileManager=self.file_manager is not None,
|
||||||
enableAutoSaving=self.config.output_dir is not None,
|
enableAutoSaving=self.config.output_dir is not None,
|
||||||
enableControlnet=self.model_manager.enable_controlnet,
|
enableControlnet=self.model_manager.enable_controlnet,
|
||||||
@ -388,38 +394,3 @@ class Api:
|
|||||||
cpu_offload=self.config.cpu_offload,
|
cpu_offload=self.config.cpu_offload,
|
||||||
callback=diffuser_callback,
|
callback=diffuser_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from iopaint.schema import InteractiveSegModel, RealESRGANModel
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
api = Api(
|
|
||||||
app,
|
|
||||||
ApiConfig(
|
|
||||||
host="127.0.0.1",
|
|
||||||
port=8080,
|
|
||||||
model="lama",
|
|
||||||
no_half=False,
|
|
||||||
cpu_offload=False,
|
|
||||||
disable_nsfw_checker=False,
|
|
||||||
cpu_textencoder=False,
|
|
||||||
device="cpu",
|
|
||||||
input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images",
|
|
||||||
output_dir="/Users/cwq/code/github/lama-cleaner/tmp",
|
|
||||||
quality=100,
|
|
||||||
enable_interactive_seg=False,
|
|
||||||
interactive_seg_model=InteractiveSegModel.vit_b,
|
|
||||||
interactive_seg_device="cpu",
|
|
||||||
enable_remove_bg=False,
|
|
||||||
enable_anime_seg=False,
|
|
||||||
enable_realesrgan=False,
|
|
||||||
realesrgan_device="cpu",
|
|
||||||
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
|
|
||||||
enable_gfpgan=False,
|
|
||||||
gfpgan_device="cpu",
|
|
||||||
enable_restoreformer=False,
|
|
||||||
restoreformer_device="cpu",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
api.launch()
|
|
||||||
|
@ -37,16 +37,31 @@ class InteractiveSeg(BasePlugin):
|
|||||||
|
|
||||||
def __init__(self, model_name, device):
|
def __init__(self, model_name, device):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.model_name = model_name
|
||||||
|
self.device = device
|
||||||
|
self._init_session(model_name)
|
||||||
|
|
||||||
|
def _init_session(self, model_name: str):
|
||||||
model_path = download_model(
|
model_path = download_model(
|
||||||
SEGMENT_ANYTHING_MODELS[model_name]["url"],
|
SEGMENT_ANYTHING_MODELS[model_name]["url"],
|
||||||
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
|
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
|
||||||
)
|
)
|
||||||
logger.info(f"SegmentAnything model path: {model_path}")
|
logger.info(f"SegmentAnything model path: {model_path}")
|
||||||
self.predictor = SamPredictor(
|
self.predictor = SamPredictor(
|
||||||
sam_model_registry[model_name](checkpoint=model_path).to(device)
|
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||||
)
|
)
|
||||||
self.prev_img_md5 = None
|
self.prev_img_md5 = None
|
||||||
|
|
||||||
|
def switch_model(self, new_model_name):
|
||||||
|
if self.model_name == new_model_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
|
||||||
|
)
|
||||||
|
self._init_session(new_model_name)
|
||||||
|
self.model_name = new_model_name
|
||||||
|
|
||||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||||
|
@ -427,6 +427,8 @@ class ServerConfigResponse(BaseModel):
|
|||||||
removeBGModels: List[RemoveBGModel]
|
removeBGModels: List[RemoveBGModel]
|
||||||
realesrganModel: RealESRGANModel
|
realesrganModel: RealESRGANModel
|
||||||
realesrganModels: List[RealESRGANModel]
|
realesrganModels: List[RealESRGANModel]
|
||||||
|
interactiveSegModel: InteractiveSegModel
|
||||||
|
interactiveSegModels: List[InteractiveSegModel]
|
||||||
enableFileManager: bool
|
enableFileManager: bool
|
||||||
enableAutoSaving: bool
|
enableAutoSaving: bool
|
||||||
enableControlnet: bool
|
enableControlnet: bool
|
||||||
|
@ -58,6 +58,7 @@ const formSchema = z.object({
|
|||||||
enableAutoExtractPrompt: z.boolean(),
|
enableAutoExtractPrompt: z.boolean(),
|
||||||
removeBGModel: z.string(),
|
removeBGModel: z.string(),
|
||||||
realesrganModel: z.string(),
|
realesrganModel: z.string(),
|
||||||
|
interactiveSegModel: z.string(),
|
||||||
})
|
})
|
||||||
|
|
||||||
const TAB_GENERAL = "General"
|
const TAB_GENERAL = "General"
|
||||||
@ -110,6 +111,7 @@ export function SettingsDialog() {
|
|||||||
outputDirectory: fileManagerState.outputDirectory,
|
outputDirectory: fileManagerState.outputDirectory,
|
||||||
removeBGModel: serverConfig?.removeBGModel,
|
removeBGModel: serverConfig?.removeBGModel,
|
||||||
realesrganModel: serverConfig?.realesrganModel,
|
realesrganModel: serverConfig?.realesrganModel,
|
||||||
|
interactiveSegModel: serverConfig?.interactiveSegModel,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -118,6 +120,7 @@ export function SettingsDialog() {
|
|||||||
setServerConfig(serverConfig)
|
setServerConfig(serverConfig)
|
||||||
form.setValue("removeBGModel", serverConfig.removeBGModel)
|
form.setValue("removeBGModel", serverConfig.removeBGModel)
|
||||||
form.setValue("realesrganModel", serverConfig.realesrganModel)
|
form.setValue("realesrganModel", serverConfig.realesrganModel)
|
||||||
|
form.setValue("interactiveSegModel", serverConfig.interactiveSegModel)
|
||||||
}
|
}
|
||||||
}, [form, serverConfig])
|
}, [form, serverConfig])
|
||||||
|
|
||||||
@ -138,13 +141,19 @@ export function SettingsDialog() {
|
|||||||
|
|
||||||
const shouldSwitchModel = model.name !== settings.model.name
|
const shouldSwitchModel = model.name !== settings.model.name
|
||||||
const shouldSwitchRemoveBGModel =
|
const shouldSwitchRemoveBGModel =
|
||||||
serverConfig?.removeBGModel !== values.removeBGModel
|
serverConfig?.removeBGModel !== values.removeBGModel && removeBGEnabled
|
||||||
const shouldSwitchRealesrganModel =
|
const shouldSwitchRealesrganModel =
|
||||||
serverConfig?.realesrganModel !== values.realesrganModel
|
serverConfig?.realesrganModel !== values.realesrganModel &&
|
||||||
|
realesrganEnabled
|
||||||
|
const shouldSwitchInteractiveModel =
|
||||||
|
serverConfig?.interactiveSegModel !== values.interactiveSegModel &&
|
||||||
|
interactiveSegEnabled
|
||||||
|
|
||||||
const showModelSwitching =
|
const showModelSwitching =
|
||||||
shouldSwitchModel ||
|
shouldSwitchModel ||
|
||||||
shouldSwitchRemoveBGModel ||
|
shouldSwitchRemoveBGModel ||
|
||||||
shouldSwitchRealesrganModel
|
shouldSwitchRealesrganModel ||
|
||||||
|
shouldSwitchInteractiveModel
|
||||||
|
|
||||||
if (showModelSwitching) {
|
if (showModelSwitching) {
|
||||||
const newModelSwitchingTexts: string[] = []
|
const newModelSwitchingTexts: string[] = []
|
||||||
@ -163,6 +172,11 @@ export function SettingsDialog() {
|
|||||||
`Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}`
|
`Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}`
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
if (shouldSwitchInteractiveModel) {
|
||||||
|
newModelSwitchingTexts.push(
|
||||||
|
`Switching ${PluginName.InteractiveSeg} model from ${serverConfig?.interactiveSegModel} to ${values.interactiveSegModel}`
|
||||||
|
)
|
||||||
|
}
|
||||||
setModelSwitchingTexts(newModelSwitchingTexts)
|
setModelSwitchingTexts(newModelSwitchingTexts)
|
||||||
|
|
||||||
updateAppState({ disableShortCuts: true })
|
updateAppState({ disableShortCuts: true })
|
||||||
@ -195,7 +209,7 @@ export function SettingsDialog() {
|
|||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
toast({
|
toast({
|
||||||
variant: "destructive",
|
variant: "destructive",
|
||||||
title: `Switch RemoveBG model to ${model.name} failed: ${error}`,
|
title: `Switch RemoveBG model to ${values.removeBGModel} failed: ${error}`,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -212,7 +226,24 @@ export function SettingsDialog() {
|
|||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
toast({
|
toast({
|
||||||
variant: "destructive",
|
variant: "destructive",
|
||||||
title: `Switch RealESRGAN model to ${model.name} failed: ${error}`,
|
title: `Switch RealESRGAN model to ${values.realesrganModel} failed: ${error}`,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldSwitchInteractiveModel) {
|
||||||
|
try {
|
||||||
|
const res = await switchPluginModel(
|
||||||
|
PluginName.InteractiveSeg,
|
||||||
|
values.interactiveSegModel
|
||||||
|
)
|
||||||
|
if (res.status !== 200) {
|
||||||
|
throw new Error(res.statusText)
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
toast({
|
||||||
|
variant: "destructive",
|
||||||
|
title: `Switch ${PluginName.InteractiveSeg} model to ${values.interactiveSegModel} failed: ${error}`,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -245,6 +276,9 @@ export function SettingsDialog() {
|
|||||||
const realesrganEnabled = plugins.some(
|
const realesrganEnabled = plugins.some(
|
||||||
(plugin) => plugin.name === PluginName.RealESRGAN
|
(plugin) => plugin.name === PluginName.RealESRGAN
|
||||||
)
|
)
|
||||||
|
const interactiveSegEnabled = plugins.some(
|
||||||
|
(plugin) => plugin.name === PluginName.InteractiveSeg
|
||||||
|
)
|
||||||
|
|
||||||
function onOpenChange(value: boolean) {
|
function onOpenChange(value: boolean) {
|
||||||
toggleOpen()
|
toggleOpen()
|
||||||
@ -522,6 +556,43 @@ export function SettingsDialog() {
|
|||||||
</FormItem>
|
</FormItem>
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<Separator />
|
||||||
|
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="interactiveSegModel"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem className="flex items-center justify-between">
|
||||||
|
<div className="space-y-0.5">
|
||||||
|
<FormLabel>Interactive Segmentation</FormLabel>
|
||||||
|
<FormDescription>
|
||||||
|
Interactive Segmentation Model
|
||||||
|
</FormDescription>
|
||||||
|
</div>
|
||||||
|
<Select
|
||||||
|
onValueChange={field.onChange}
|
||||||
|
defaultValue={field.value}
|
||||||
|
disabled={!interactiveSegEnabled}
|
||||||
|
>
|
||||||
|
<FormControl>
|
||||||
|
<SelectTrigger className="w-auto">
|
||||||
|
<SelectValue placeholder="Select interactive segmentation model" />
|
||||||
|
</SelectTrigger>
|
||||||
|
</FormControl>
|
||||||
|
<SelectContent align="end">
|
||||||
|
<SelectGroup>
|
||||||
|
{serverConfig?.interactiveSegModels.map((model) => (
|
||||||
|
<SelectItem key={model} value={model}>
|
||||||
|
{model}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,8 @@ export interface ServerConfig {
|
|||||||
removeBGModels: string[]
|
removeBGModels: string[]
|
||||||
realesrganModel: string
|
realesrganModel: string
|
||||||
realesrganModels: string[]
|
realesrganModels: string[]
|
||||||
|
interactiveSegModel: string
|
||||||
|
interactiveSegModels: string[]
|
||||||
enableFileManager: boolean
|
enableFileManager: boolean
|
||||||
enableAutoSaving: boolean
|
enableAutoSaving: boolean
|
||||||
enableControlnet: boolean
|
enableControlnet: boolean
|
||||||
|
Loading…
Reference in New Issue
Block a user