add realesrGAN selection
This commit is contained in:
parent
8060e16c70
commit
f52dbc1091
@ -43,7 +43,7 @@ from iopaint.helper import (
|
||||
)
|
||||
from iopaint.model.utils import torch_gc
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.plugins import build_plugins
|
||||
from iopaint.plugins import build_plugins, RealESRGANUpscaler
|
||||
from iopaint.plugins.base_plugin import BasePlugin
|
||||
from iopaint.plugins.remove_bg import RemoveBG
|
||||
from iopaint.schema import (
|
||||
@ -59,6 +59,7 @@ from iopaint.schema import (
|
||||
RemoveBGModel,
|
||||
SwitchPluginModelRequest,
|
||||
ModelInfo,
|
||||
RealESRGANModel,
|
||||
)
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||
@ -192,6 +193,8 @@ class Api:
|
||||
self.plugins[req.plugin_name].switch_model(req.model_name)
|
||||
if req.plugin_name == RemoveBG.name:
|
||||
self.config.remove_bg_model = req.model_name
|
||||
if req.plugin_name == RealESRGANUpscaler.name:
|
||||
self.config.realesrgan_model = req.model_name
|
||||
|
||||
def api_server_config(self) -> ServerConfigResponse:
|
||||
plugins = []
|
||||
@ -209,6 +212,8 @@ class Api:
|
||||
modelInfos=self.model_manager.scan_models(),
|
||||
removeBGModel=self.config.remove_bg_model,
|
||||
removeBGModels=RemoveBGModel.values(),
|
||||
realesrganModel=self.config.realesrgan_model,
|
||||
realesrganModels=RealESRGANModel.values(),
|
||||
enableFileManager=self.file_manager is not None,
|
||||
enableAutoSaving=self.config.output_dir is not None,
|
||||
enableControlnet=self.model_manager.enable_controlnet,
|
||||
|
@ -14,6 +14,12 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
|
||||
def __init__(self, name, device, no_half=False):
|
||||
super().__init__()
|
||||
self.model_name = name
|
||||
self.device = device
|
||||
self.no_half = no_half
|
||||
self._init_model(name)
|
||||
|
||||
def _init_model(self, name):
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
@ -70,13 +76,19 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
scale=model_info["scale"],
|
||||
model_path=model_path,
|
||||
model=model_info["model"](),
|
||||
half=True if "cuda" in str(device) and not no_half else False,
|
||||
half=True if "cuda" in str(self.device) and not self.no_half else False,
|
||||
tile=512,
|
||||
tile_pad=10,
|
||||
pre_pad=10,
|
||||
device=device,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def switch_model(self, new_model_name: str):
|
||||
if self.model_name == new_model_name:
|
||||
return
|
||||
self._init_model(new_model_name)
|
||||
self.model_name = new_model_name
|
||||
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||
|
@ -423,7 +423,9 @@ class ServerConfigResponse(BaseModel):
|
||||
plugins: List[PluginInfo]
|
||||
modelInfos: List[ModelInfo]
|
||||
removeBGModel: RemoveBGModel
|
||||
removeBGModels: List[str]
|
||||
removeBGModels: List[RemoveBGModel]
|
||||
realesrganModel: RealESRGANModel
|
||||
realesrganModels: List[RealESRGANModel]
|
||||
enableFileManager: bool
|
||||
enableAutoSaving: bool
|
||||
enableControlnet: bool
|
||||
|
@ -44,7 +44,7 @@ default_configs = dict(
|
||||
interactive_seg_model=InteractiveSegModel.vit_b,
|
||||
interactive_seg_device=Device.cpu,
|
||||
enable_remove_bg=False,
|
||||
remove_bg_model=RemoveBGModel.u2net,
|
||||
remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
|
||||
enable_anime_seg=False,
|
||||
enable_realesrgan=False,
|
||||
realesrgan_device=Device.cpu,
|
||||
|
@ -57,6 +57,7 @@ const formSchema = z.object({
|
||||
enableUploadMask: z.boolean(),
|
||||
enableAutoExtractPrompt: z.boolean(),
|
||||
removeBGModel: z.string(),
|
||||
realesrganModel: z.string(),
|
||||
})
|
||||
|
||||
const TAB_GENERAL = "General"
|
||||
@ -108,6 +109,7 @@ export function SettingsDialog() {
|
||||
inputDirectory: fileManagerState.inputDirectory,
|
||||
outputDirectory: fileManagerState.outputDirectory,
|
||||
removeBGModel: serverConfig?.removeBGModel,
|
||||
realesrganModel: serverConfig?.realesrganModel,
|
||||
},
|
||||
})
|
||||
|
||||
@ -115,6 +117,7 @@ export function SettingsDialog() {
|
||||
if (serverConfig) {
|
||||
setServerConfig(serverConfig)
|
||||
form.setValue("removeBGModel", serverConfig.removeBGModel)
|
||||
form.setValue("realesrganModel", serverConfig.realesrganModel)
|
||||
}
|
||||
}, [form, serverConfig])
|
||||
|
||||
@ -136,7 +139,12 @@ export function SettingsDialog() {
|
||||
const shouldSwitchModel = model.name !== settings.model.name
|
||||
const shouldSwitchRemoveBGModel =
|
||||
serverConfig?.removeBGModel !== values.removeBGModel
|
||||
const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel
|
||||
const shouldSwitchRealesrganModel =
|
||||
serverConfig?.realesrganModel !== values.realesrganModel
|
||||
const showModelSwitching =
|
||||
shouldSwitchModel ||
|
||||
shouldSwitchRemoveBGModel ||
|
||||
shouldSwitchRealesrganModel
|
||||
|
||||
if (showModelSwitching) {
|
||||
const newModelSwitchingTexts: string[] = []
|
||||
@ -147,7 +155,12 @@ export function SettingsDialog() {
|
||||
}
|
||||
if (shouldSwitchRemoveBGModel) {
|
||||
newModelSwitchingTexts.push(
|
||||
`Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
|
||||
`Switching RemoveBG model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
|
||||
)
|
||||
}
|
||||
if (shouldSwitchRealesrganModel) {
|
||||
newModelSwitchingTexts.push(
|
||||
`Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}`
|
||||
)
|
||||
}
|
||||
setModelSwitchingTexts(newModelSwitchingTexts)
|
||||
@ -182,7 +195,24 @@ export function SettingsDialog() {
|
||||
} catch (error: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: `Switch removebg model to ${model.name} failed: ${error}`,
|
||||
title: `Switch RemoveBG model to ${model.name} failed: ${error}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldSwitchRealesrganModel) {
|
||||
try {
|
||||
const res = await switchPluginModel(
|
||||
PluginName.RealESRGAN,
|
||||
values.realesrganModel
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(res.statusText)
|
||||
}
|
||||
} catch (error: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: `Switch RealESRGAN model to ${model.name} failed: ${error}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -212,6 +242,9 @@ export function SettingsDialog() {
|
||||
const removeBGEnabled = plugins.some(
|
||||
(plugin) => plugin.name === PluginName.RemoveBG
|
||||
)
|
||||
const realesrganEnabled = plugins.some(
|
||||
(plugin) => plugin.name === PluginName.RealESRGAN
|
||||
)
|
||||
|
||||
function onOpenChange(value: boolean) {
|
||||
toggleOpen()
|
||||
@ -437,7 +470,7 @@ export function SettingsDialog() {
|
||||
disabled={!removeBGEnabled}
|
||||
>
|
||||
<FormControl>
|
||||
<SelectTrigger className="w-[200px]">
|
||||
<SelectTrigger className="w-auto">
|
||||
<SelectValue placeholder="Select removebg model" />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
@ -454,6 +487,41 @@ export function SettingsDialog() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<Separator />
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="realesrganModel"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center justify-between">
|
||||
<div className="space-y-0.5">
|
||||
<FormLabel>RealESRGAN</FormLabel>
|
||||
<FormDescription>RealESRGAN Model</FormDescription>
|
||||
</div>
|
||||
<Select
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
disabled={!realesrganEnabled}
|
||||
>
|
||||
<FormControl>
|
||||
<SelectTrigger className="w-auto">
|
||||
<SelectValue placeholder="Select RealESRGAN model" />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent align="end">
|
||||
<SelectGroup>
|
||||
{serverConfig?.realesrganModels.map((model) => (
|
||||
<SelectItem key={model} value={model}>
|
||||
{model}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
@ -17,6 +17,8 @@ export interface ServerConfig {
|
||||
modelInfos: ModelInfo[]
|
||||
removeBGModel: string
|
||||
removeBGModels: string[]
|
||||
realesrganModel: string
|
||||
realesrganModels: string[]
|
||||
enableFileManager: boolean
|
||||
enableAutoSaving: boolean
|
||||
enableControlnet: boolean
|
||||
|
Loading…
Reference in New Issue
Block a user