add realesrGAN selection

This commit is contained in:
Qing 2024-02-08 17:16:57 +08:00
parent 8060e16c70
commit f52dbc1091
6 changed files with 98 additions and 9 deletions

View File

@ -43,7 +43,7 @@ from iopaint.helper import (
) )
from iopaint.model.utils import torch_gc from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins from iopaint.plugins import build_plugins, RealESRGANUpscaler
from iopaint.plugins.base_plugin import BasePlugin from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import ( from iopaint.schema import (
@ -59,6 +59,7 @@ from iopaint.schema import (
RemoveBGModel, RemoveBGModel,
SwitchPluginModelRequest, SwitchPluginModelRequest,
ModelInfo, ModelInfo,
RealESRGANModel,
) )
CURRENT_DIR = Path(__file__).parent.absolute().resolve() CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@ -192,6 +193,8 @@ class Api:
self.plugins[req.plugin_name].switch_model(req.model_name) self.plugins[req.plugin_name].switch_model(req.model_name)
if req.plugin_name == RemoveBG.name: if req.plugin_name == RemoveBG.name:
self.config.remove_bg_model = req.model_name self.config.remove_bg_model = req.model_name
if req.plugin_name == RealESRGANUpscaler.name:
self.config.realesrgan_model = req.model_name
def api_server_config(self) -> ServerConfigResponse: def api_server_config(self) -> ServerConfigResponse:
plugins = [] plugins = []
@ -209,6 +212,8 @@ class Api:
modelInfos=self.model_manager.scan_models(), modelInfos=self.model_manager.scan_models(),
removeBGModel=self.config.remove_bg_model, removeBGModel=self.config.remove_bg_model,
removeBGModels=RemoveBGModel.values(), removeBGModels=RemoveBGModel.values(),
realesrganModel=self.config.realesrgan_model,
realesrganModels=RealESRGANModel.values(),
enableFileManager=self.file_manager is not None, enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None, enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet, enableControlnet=self.model_manager.enable_controlnet,

View File

@ -14,6 +14,12 @@ class RealESRGANUpscaler(BasePlugin):
def __init__(self, name, device, no_half=False): def __init__(self, name, device, no_half=False):
super().__init__() super().__init__()
self.model_name = name
self.device = device
self.no_half = no_half
self._init_model(name)
def _init_model(self, name):
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact
@ -70,13 +76,19 @@ class RealESRGANUpscaler(BasePlugin):
scale=model_info["scale"], scale=model_info["scale"],
model_path=model_path, model_path=model_path,
model=model_info["model"](), model=model_info["model"](),
half=True if "cuda" in str(device) and not no_half else False, half=True if "cuda" in str(self.device) and not self.no_half else False,
tile=512, tile=512,
tile_pad=10, tile_pad=10,
pre_pad=10, pre_pad=10,
device=device, device=self.device,
) )
def switch_model(self, new_model_name: str):
if self.model_name == new_model_name:
return
self._init_model(new_model_name)
self.model_name = new_model_name
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}") logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")

View File

@ -423,7 +423,9 @@ class ServerConfigResponse(BaseModel):
plugins: List[PluginInfo] plugins: List[PluginInfo]
modelInfos: List[ModelInfo] modelInfos: List[ModelInfo]
removeBGModel: RemoveBGModel removeBGModel: RemoveBGModel
removeBGModels: List[str] removeBGModels: List[RemoveBGModel]
realesrganModel: RealESRGANModel
realesrganModels: List[RealESRGANModel]
enableFileManager: bool enableFileManager: bool
enableAutoSaving: bool enableAutoSaving: bool
enableControlnet: bool enableControlnet: bool

View File

@ -44,7 +44,7 @@ default_configs = dict(
interactive_seg_model=InteractiveSegModel.vit_b, interactive_seg_model=InteractiveSegModel.vit_b,
interactive_seg_device=Device.cpu, interactive_seg_device=Device.cpu,
enable_remove_bg=False, enable_remove_bg=False,
remove_bg_model=RemoveBGModel.u2net, remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
enable_anime_seg=False, enable_anime_seg=False,
enable_realesrgan=False, enable_realesrgan=False,
realesrgan_device=Device.cpu, realesrgan_device=Device.cpu,

View File

@ -57,6 +57,7 @@ const formSchema = z.object({
enableUploadMask: z.boolean(), enableUploadMask: z.boolean(),
enableAutoExtractPrompt: z.boolean(), enableAutoExtractPrompt: z.boolean(),
removeBGModel: z.string(), removeBGModel: z.string(),
realesrganModel: z.string(),
}) })
const TAB_GENERAL = "General" const TAB_GENERAL = "General"
@ -108,6 +109,7 @@ export function SettingsDialog() {
inputDirectory: fileManagerState.inputDirectory, inputDirectory: fileManagerState.inputDirectory,
outputDirectory: fileManagerState.outputDirectory, outputDirectory: fileManagerState.outputDirectory,
removeBGModel: serverConfig?.removeBGModel, removeBGModel: serverConfig?.removeBGModel,
realesrganModel: serverConfig?.realesrganModel,
}, },
}) })
@ -115,6 +117,7 @@ export function SettingsDialog() {
if (serverConfig) { if (serverConfig) {
setServerConfig(serverConfig) setServerConfig(serverConfig)
form.setValue("removeBGModel", serverConfig.removeBGModel) form.setValue("removeBGModel", serverConfig.removeBGModel)
form.setValue("realesrganModel", serverConfig.realesrganModel)
} }
}, [form, serverConfig]) }, [form, serverConfig])
@ -136,7 +139,12 @@ export function SettingsDialog() {
const shouldSwitchModel = model.name !== settings.model.name const shouldSwitchModel = model.name !== settings.model.name
const shouldSwitchRemoveBGModel = const shouldSwitchRemoveBGModel =
serverConfig?.removeBGModel !== values.removeBGModel serverConfig?.removeBGModel !== values.removeBGModel
const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel const shouldSwitchRealesrganModel =
serverConfig?.realesrganModel !== values.realesrganModel
const showModelSwitching =
shouldSwitchModel ||
shouldSwitchRemoveBGModel ||
shouldSwitchRealesrganModel
if (showModelSwitching) { if (showModelSwitching) {
const newModelSwitchingTexts: string[] = [] const newModelSwitchingTexts: string[] = []
@ -147,7 +155,12 @@ export function SettingsDialog() {
} }
if (shouldSwitchRemoveBGModel) { if (shouldSwitchRemoveBGModel) {
newModelSwitchingTexts.push( newModelSwitchingTexts.push(
`Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}` `Switching RemoveBG model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
)
}
if (shouldSwitchRealesrganModel) {
newModelSwitchingTexts.push(
`Switching RealESRGAN model from ${serverConfig?.realesrganModel} to ${values.realesrganModel}`
) )
} }
setModelSwitchingTexts(newModelSwitchingTexts) setModelSwitchingTexts(newModelSwitchingTexts)
@ -182,7 +195,24 @@ export function SettingsDialog() {
} catch (error: any) { } catch (error: any) {
toast({ toast({
variant: "destructive", variant: "destructive",
title: `Switch removebg model to ${model.name} failed: ${error}`, title: `Switch RemoveBG model to ${model.name} failed: ${error}`,
})
}
}
if (shouldSwitchRealesrganModel) {
try {
const res = await switchPluginModel(
PluginName.RealESRGAN,
values.realesrganModel
)
if (res.status !== 200) {
throw new Error(res.statusText)
}
} catch (error: any) {
toast({
variant: "destructive",
title: `Switch RealESRGAN model to ${model.name} failed: ${error}`,
}) })
} }
} }
@ -212,6 +242,9 @@ export function SettingsDialog() {
const removeBGEnabled = plugins.some( const removeBGEnabled = plugins.some(
(plugin) => plugin.name === PluginName.RemoveBG (plugin) => plugin.name === PluginName.RemoveBG
) )
const realesrganEnabled = plugins.some(
(plugin) => plugin.name === PluginName.RealESRGAN
)
function onOpenChange(value: boolean) { function onOpenChange(value: boolean) {
toggleOpen() toggleOpen()
@ -437,7 +470,7 @@ export function SettingsDialog() {
disabled={!removeBGEnabled} disabled={!removeBGEnabled}
> >
<FormControl> <FormControl>
<SelectTrigger className="w-[200px]"> <SelectTrigger className="w-auto">
<SelectValue placeholder="Select removebg model" /> <SelectValue placeholder="Select removebg model" />
</SelectTrigger> </SelectTrigger>
</FormControl> </FormControl>
@ -454,6 +487,41 @@ export function SettingsDialog() {
</FormItem> </FormItem>
)} )}
/> />
<Separator />
<FormField
control={form.control}
name="realesrganModel"
render={({ field }) => (
<FormItem className="flex items-center justify-between">
<div className="space-y-0.5">
<FormLabel>RealESRGAN</FormLabel>
<FormDescription>RealESRGAN Model</FormDescription>
</div>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
disabled={!realesrganEnabled}
>
<FormControl>
<SelectTrigger className="w-auto">
<SelectValue placeholder="Select RealESRGAN model" />
</SelectTrigger>
</FormControl>
<SelectContent align="end">
<SelectGroup>
{serverConfig?.realesrganModels.map((model) => (
<SelectItem key={model} value={model}>
{model}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormItem>
)}
/>
</div> </div>
) )
} }

View File

@ -17,6 +17,8 @@ export interface ServerConfig {
modelInfos: ModelInfo[] modelInfos: ModelInfo[]
removeBGModel: string removeBGModel: string
removeBGModels: string[] removeBGModels: string[]
realesrganModel: string
realesrganModels: string[]
enableFileManager: boolean enableFileManager: boolean
enableAutoSaving: boolean enableAutoSaving: boolean
enableControlnet: boolean enableControlnet: boolean