From a2fd5bb3eaf4f5b44b0854475f3e869df483a26e Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Jan 2024 11:07:35 +0800 Subject: [PATCH] update plugins --- lama_cleaner/api.py | 20 +- lama_cleaner/plugins/anime_seg.py | 3 +- lama_cleaner/plugins/base_plugin.py | 6 +- lama_cleaner/plugins/gfpgan_plugin.py | 3 +- lama_cleaner/plugins/interactive_seg.py | 11 +- lama_cleaner/plugins/realesrgan.py | 8 +- lama_cleaner/plugins/remove_bg.py | 3 +- lama_cleaner/plugins/restoreformer.py | 3 +- lama_cleaner/schema.py | 15 +- lama_cleaner/tests/test_plugins.py | 44 ++-- web_app/src/App.tsx | 2 +- web_app/src/components/Editor.tsx | 3 - web_app/src/components/FileManager.tsx | 20 +- web_app/src/components/Settings.tsx | 32 ++- .../components/SidePanel/DiffusionOptions.tsx | 127 ++++++----- web_app/src/components/Workspace.tsx | 10 +- web_app/src/lib/api.ts | 197 +++++++++--------- web_app/src/lib/states.ts | 42 ++-- web_app/src/lib/types.ts | 15 ++ 19 files changed, 337 insertions(+), 227 deletions(-) diff --git a/lama_cleaner/api.py b/lama_cleaner/api.py index 5afb0b2..4f326a1 100644 --- a/lama_cleaner/api.py +++ b/lama_cleaner/api.py @@ -210,26 +210,26 @@ class Api: ) def api_run_plugin(self, req: RunPluginRequest): + ext = "png" if req.name not in self.plugins: raise HTTPException(status_code=404, detail="Plugin not found") - image, alpha_channel, infos = decode_base64_to_image(req.image) - bgr_res = self.plugins[req.name].run(image, req) + rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) + bgr_np_img = self.plugins[req.name](rgb_np_img, req) torch_gc() if req.name == InteractiveSeg.name: return Response( - content=numpy_to_bytes(bgr_res, "png"), - media_type="image/png", + content=numpy_to_bytes(bgr_np_img, ext), + media_type=f"image/{ext}", ) - ext = "png" - if req.name in [RemoveBG.name, AnimeSeg.name]: - rgb_res = bgr_res + if bgr_np_img.shape[2] == 4: + rgba_np_img = bgr_np_img else: - rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB) - rgb_res = concat_alpha_channel(rgb_res, alpha_channel) + rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB) + rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel) return Response( content=pil_to_bytes( - Image.fromarray(rgb_res), + Image.fromarray(rgba_np_img), ext=ext, quality=self.config.quality, infos=infos, diff --git a/lama_cleaner/plugins/anime_seg.py b/lama_cleaner/plugins/anime_seg.py index ecfc7d1..bf6bbf2 100644 --- a/lama_cleaner/plugins/anime_seg.py +++ b/lama_cleaner/plugins/anime_seg.py @@ -7,6 +7,7 @@ from PIL import Image from lama_cleaner.helper import load_model from lama_cleaner.plugins.base_plugin import BasePlugin +from lama_cleaner.schema import RunPluginRequest class REBNCONV(nn.Module): @@ -425,7 +426,7 @@ class AnimeSeg(BasePlugin): ANIME_SEG_MODELS["md5"], ) - def __call__(self, rgb_np_img, files, form): + def __call__(self, rgb_np_img, req: RunPluginRequest): return self.forward(rgb_np_img) @torch.no_grad() diff --git a/lama_cleaner/plugins/base_plugin.py b/lama_cleaner/plugins/base_plugin.py index 39d7491..f57e0fc 100644 --- a/lama_cleaner/plugins/base_plugin.py +++ b/lama_cleaner/plugins/base_plugin.py @@ -1,4 +1,7 @@ from loguru import logger +import numpy as np + +from lama_cleaner.schema import RunPluginRequest class BasePlugin: @@ -8,7 +11,8 @@ class BasePlugin: logger.error(err_msg) exit(-1) - def __call__(self, rgb_np_img, files, form): + def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array: + # return RGBA np image or BGR np image ... def check_dep(self): diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py index 2422094..635bcf4 100644 --- a/lama_cleaner/plugins/gfpgan_plugin.py +++ b/lama_cleaner/plugins/gfpgan_plugin.py @@ -3,6 +3,7 @@ from loguru import logger from lama_cleaner.helper import download_model from lama_cleaner.plugins.base_plugin import BasePlugin +from lama_cleaner.schema import RunPluginRequest class GFPGANPlugin(BasePlugin): @@ -36,7 +37,7 @@ class GFPGANPlugin(BasePlugin): self.face_enhancer.face_helper.face_det.to(device) ) - def __call__(self, rgb_np_img, files, form): + def __call__(self, rgb_np_img, req: RunPluginRequest): weight = 0.5 bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) logger.info(f"GFPGAN input shape: {bgr_np_img.shape}") diff --git a/lama_cleaner/plugins/interactive_seg.py b/lama_cleaner/plugins/interactive_seg.py index 160d59f..7df184d 100644 --- a/lama_cleaner/plugins/interactive_seg.py +++ b/lama_cleaner/plugins/interactive_seg.py @@ -1,4 +1,6 @@ +import hashlib import json +from typing import List import cv2 import numpy as np @@ -7,6 +9,7 @@ from loguru import logger from lama_cleaner.helper import download_model from lama_cleaner.plugins.base_plugin import BasePlugin from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry +from lama_cleaner.schema import RunPluginRequest # 从小到大 SEGMENT_ANYTHING_MODELS = { @@ -44,11 +47,11 @@ class InteractiveSeg(BasePlugin): ) self.prev_img_md5 = None - def __call__(self, rgb_np_img, files, form): - clicks = json.loads(form["clicks"]) - return self.forward(rgb_np_img, clicks, form["img_md5"]) + def __call__(self, rgb_np_img, req: RunPluginRequest): + img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest() + return self.forward(rgb_np_img, req.clicks, img_md5) - def forward(self, rgb_np_img, clicks, img_md5): + def forward(self, rgb_np_img, clicks: List[List], img_md5: str): input_point = [] input_label = [] for click in clicks: diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py index 36f522a..3c796b3 100644 --- a/lama_cleaner/plugins/realesrgan.py +++ b/lama_cleaner/plugins/realesrgan.py @@ -6,6 +6,7 @@ from loguru import logger from lama_cleaner.const import RealESRGANModel from lama_cleaner.helper import download_model from lama_cleaner.plugins.base_plugin import BasePlugin +from lama_cleaner.schema import RunPluginRequest class RealESRGANUpscaler(BasePlugin): @@ -76,11 +77,10 @@ class RealESRGANUpscaler(BasePlugin): device=device, ) - def __call__(self, rgb_np_img, files, form): + def __call__(self, rgb_np_img, req: RunPluginRequest): bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) - scale = float(form["upscale"]) - logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}") - result = self.forward(bgr_np_img, scale) + logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}") + result = self.forward(bgr_np_img, req.scale) logger.info(f"RealESRGAN output shape: {result.shape}") return result diff --git a/lama_cleaner/plugins/remove_bg.py b/lama_cleaner/plugins/remove_bg.py index 4025198..31fc565 100644 --- a/lama_cleaner/plugins/remove_bg.py +++ b/lama_cleaner/plugins/remove_bg.py @@ -4,6 +4,7 @@ import numpy as np from torch.hub import get_dir from lama_cleaner.plugins.base_plugin import BasePlugin +from lama_cleaner.schema import RunPluginRequest class RemoveBG(BasePlugin): @@ -19,7 +20,7 @@ class RemoveBG(BasePlugin): self.session = new_session(model_name="u2net") - def __call__(self, rgb_np_img, files, form): + def __call__(self, rgb_np_img, req: RunPluginRequest): bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) return self.forward(bgr_np_img) diff --git a/lama_cleaner/plugins/restoreformer.py b/lama_cleaner/plugins/restoreformer.py index 0cd8b10..3adc441 100644 --- a/lama_cleaner/plugins/restoreformer.py +++ b/lama_cleaner/plugins/restoreformer.py @@ -3,6 +3,7 @@ from loguru import logger from lama_cleaner.helper import download_model from lama_cleaner.plugins.base_plugin import BasePlugin +from lama_cleaner.schema import RunPluginRequest class RestoreFormerPlugin(BasePlugin): @@ -31,7 +32,7 @@ class RestoreFormerPlugin(BasePlugin): bg_upsampler=upscaler.model if upscaler is not None else None, ) - def __call__(self, rgb_np_img, files, form): + def __call__(self, rgb_np_img, req: RunPluginRequest): weight = 0.5 bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}") diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 0cedefd..db67f53 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -136,6 +136,12 @@ class InpaintRequest(BaseModel): extender_height: int = Field(640, description="Extend height for extender") extender_width: int = Field(640, description="Extend width for extender") + sd_scale: float = Field( + 1.0, + description="Resize the image before doing sd inpainting, the area outside the mask will not lose quality.", + gt=0.0, + le=1.0, + ) sd_mask_blur: int = Field( 33, description="Blur the edge of mask area. The higher the number the smoother blend with the original image", @@ -143,6 +149,7 @@ class InpaintRequest(BaseModel): sd_strength: float = Field( 1.0, description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image", + le=1.0, ) sd_steps: int = Field( 50, @@ -202,7 +209,9 @@ class InpaintRequest(BaseModel): # ControlNet enable_controlnet: bool = Field(False, description="Enable controlnet") - controlnet_conditioning_scale: float = Field(0.4, description="Conditioning scale") + controlnet_conditioning_scale: float = Field( + 0.4, description="Conditioning scale", gt=0.0, le=1.0 + ) controlnet_method: str = Field( "lllyasviel/control_v11p_sd15_canny", description="Controlnet method" ) @@ -214,6 +223,8 @@ class InpaintRequest(BaseModel): fitting_degree: float = Field( 1.0, description="Control the fitting degree of the generated objects to the mask shape.", + gt=0.0, + le=1.0, ) @field_validator("sd_seed") @@ -226,7 +237,7 @@ class InpaintRequest(BaseModel): class RunPluginRequest(BaseModel): name: str - image: Optional[str] = Field(..., description="base64 encoded image") + image: str = Field(..., description="base64 encoded image") clicks: List[List[int]] = Field( [], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]" ) diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py index 8d782ef..721ad51 100644 --- a/lama_cleaner/tests/test_plugins.py +++ b/lama_cleaner/tests/test_plugins.py @@ -1,8 +1,11 @@ import hashlib import os import time +from PIL import Image +from lama_cleaner.helper import encode_pil_to_base64 from lama_cleaner.plugins.anime_seg import AnimeSeg +from lama_cleaner.schema import RunPluginRequest from lama_cleaner.tests.utils import check_device, current_dir, save_dir os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -22,6 +25,8 @@ img_p = current_dir / "bunny.jpeg" img_bytes = open(img_p, "rb").read() bgr_img = cv2.imread(str(img_p)) rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) +rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {}) +bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {}) def _save(img, name): @@ -30,15 +35,18 @@ def _save(img, name): def test_remove_bg(): model = RemoveBG() - res = model.forward(bgr_img) - res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA) + rgba_np_img = model( + rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) + ) + res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA) _save(res, "test_remove_bg.png") def test_anime_seg(): model = AnimeSeg() img = cv2.imread(str(current_dir / "anime_test.png")) - res = model.forward(img) + img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {}) + res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) assert len(res.shape) == 3 assert res.shape[-1] == 4 _save(res, "test_anime_seg.png") @@ -48,10 +56,16 @@ def test_anime_seg(): def test_upscale(device): check_device(device) model = RealESRGANUpscaler("realesr-general-x4v3", device) - res = model.forward(bgr_img, 2) + res = model( + rgb_img, + RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2), + ) _save(res, f"test_upscale_x2_{device}.png") - res = model.forward(bgr_img, 4) + res = model( + rgb_img, + RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4), + ) _save(res, f"test_upscale_x4_{device}.png") @@ -59,7 +73,7 @@ def test_upscale(device): def test_gfpgan(device): check_device(device) model = GFPGANPlugin(device) - res = model(rgb_img, None, None) + res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)) _save(res, f"test_gfpgan_{device}.png") @@ -67,20 +81,24 @@ def test_gfpgan(device): def test_restoreformer(device): check_device(device) model = RestoreFormerPlugin(device) - res = model(rgb_img, None, None) + res = model( + rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64) + ) _save(res, f"test_restoreformer_{device}.png") @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) def test_segment_anything(device): check_device(device) - img_md5 = hashlib.md5(img_bytes).hexdigest() model = InteractiveSeg("vit_l", device) - new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5) + new_mask = model( + rgb_img, + RunPluginRequest( + name=InteractiveSeg.name, + image=rgb_img_base64, + clicks=([[448 // 2, 394 // 2, 1]]), + ), + ) save_name = f"test_segment_anything_{device}.png" _save(new_mask, save_name) - - start = time.time() - model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5) - print(f"Time for {save_name}: {time.time() - start:.2f}s") diff --git a/web_app/src/App.tsx b/web_app/src/App.tsx index e1560f2..75431ee 100644 --- a/web_app/src/App.tsx +++ b/web_app/src/App.tsx @@ -42,7 +42,7 @@ function Home() { useEffect(() => { const fetchServerConfig = async () => { - const serverConfig = await getServerConfig().then((res) => res.json()) + const serverConfig = await getServerConfig() setServerConfig(serverConfig) if (serverConfig.isDesktop) { // Keeping GUI Window Open diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index 2703d38..5393593 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -363,9 +363,6 @@ export default function Editor(props: EditorProps) { undefined, newClicks ) - if (!res) { - throw new Error("Something went wrong on server side.") - } const { blob } = res const img = new Image() img.onload = () => { diff --git a/web_app/src/components/FileManager.tsx b/web_app/src/components/FileManager.tsx index 68f7fb7..84e04ed 100644 --- a/web_app/src/components/FileManager.tsx +++ b/web_app/src/components/FileManager.tsx @@ -78,6 +78,7 @@ export default function FileManager(props: Props) { const ref = useRef(null) const debouncedSearchText = useDebounce(fileManagerState.searchText, 300) const [tab, setTab] = useState(IMAGE_TAB) + const [filenames, setFilenames] = useState([]) const [photos, setPhotos] = useState([]) const [photoIndex, setPhotoIndex] = useState(0) @@ -131,13 +132,28 @@ export default function FileManager(props: Props) { [open, closeScrollTop] ) + useEffect(() => { + const fetchData = async () => { + try { + const filenames = await getMedias(tab) + setFilenames(filenames) + } catch (e: any) { + toast({ + variant: "destructive", + title: "Uh oh! Something went wrong.", + description: e.message ? e.message : e.toString(), + }) + } + } + fetchData() + }, [tab]) + useEffect(() => { if (!open) { return } const fetchData = async () => { try { - const filenames = await getMedias(tab) let filteredFilenames = filenames if (debouncedSearchText) { const fuse = new Fuse(filteredFilenames, { @@ -173,7 +189,7 @@ export default function FileManager(props: Props) { } } fetchData() - }, [tab, debouncedSearchText, fileManagerState, photoWidth, open]) + }, [filenames, debouncedSearchText, fileManagerState, photoWidth, open]) const onScroll = (event: SyntheticEvent) => { setScrollTop(event.currentTarget.scrollTop) diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index 38843c1..ff1c692 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -99,7 +99,7 @@ export function SettingsDialog() { }, }) - function onSubmit(values: z.infer) { + async function onSubmit(values: z.infer) { // Do something with the form values. ✅ This will be type-safe and validated. updateSettings({ enableDownloadMask: values.enableDownloadMask, @@ -116,24 +116,22 @@ export function SettingsDialog() { if (model.name !== settings.model.name) { toggleOpenModelSwitching() updateAppState({ disableShortCuts: true }) - switchModel(model.name) - .then((res) => { - toast({ - title: `Switch to ${model.name} success`, - }) - setAppModel(model) + try { + const newModel = await switchModel(model.name) + toast({ + title: `Switch to ${newModel.name} success`, }) - .catch((error: any) => { - toast({ - variant: "destructive", - title: `Switch to ${model.name} failed: ${error}`, - }) - setModel(settings.model) - }) - .finally(() => { - toggleOpenModelSwitching() - updateAppState({ disableShortCuts: false }) + setAppModel(model) + } catch (error: any) { + toast({ + variant: "destructive", + title: `Switch to ${model.name} failed: ${error}`, }) + setModel(settings.model) + } finally { + toggleOpenModelSwitching() + updateAppState({ disableShortCuts: false }) + } } } diff --git a/web_app/src/components/SidePanel/DiffusionOptions.tsx b/web_app/src/components/SidePanel/DiffusionOptions.tsx index e29e395..a2c9d1d 100644 --- a/web_app/src/components/SidePanel/DiffusionOptions.tsx +++ b/web_app/src/components/SidePanel/DiffusionOptions.tsx @@ -69,6 +69,27 @@ const DiffusionOptions = () => { } } + const renderCropper = () => { + return ( + + + { + updateSettings({ showCropper: value }) + if (value) { + updateSettings({ showExtender: false }) + } + }} + /> + + ) + } + const renderConterNetSetting = () => { if (!settings.model.support_controlnet) { return null @@ -558,28 +579,8 @@ const DiffusionOptions = () => { ) } - return ( -
- - - { - updateSettings({ showCropper: value }) - if (value) { - updateSettings({ showExtender: false }) - } - }} - /> - - - {renderExtender()} - {renderPowerPaintTaskType()} - + const renderSteps = () => { + return (
{ />
+ ) + } + const renderGuidanceScale = () => { + return (
{ />
+ ) + } - {renderP2PImageGuidanceScale()} - {renderStrength()} - + const renderSampler = () => { + return ( + ) + } + const renderSeed = () => { + return ( {/* 每次会从服务器返回更新该值 */} { />
+ ) + } - {renderNegativePrompt()} - - - - {renderConterNetSetting()} - {renderFreeu()} - {renderLCMLora()} - + const renderMaskBlur = () => { + return (
{ />
+ ) + } - - - { - updateSettings({ sdMatchHistograms: value }) - }} - /> - + const renderMatchHistograms = () => { + return ( + <> + + + { + updateSettings({ sdMatchHistograms: value }) + }} + /> + + + + ) + } + return ( +
+ {renderCropper()} + {renderExtender()} + {renderPowerPaintTaskType()} + {renderSteps()} + {renderGuidanceScale()} + {renderP2PImageGuidanceScale()} + {renderStrength()} + {renderSampler()} + {renderSeed()} + {renderNegativePrompt()} - + {renderConterNetSetting()} + {renderLCMLora()} + {renderMaskBlur()} + {renderMatchHistograms()} + {renderFreeu()} {renderPaintByExample()}
) diff --git a/web_app/src/components/Workspace.tsx b/web_app/src/components/Workspace.tsx index 909bb10..bef1fda 100644 --- a/web_app/src/components/Workspace.tsx +++ b/web_app/src/components/Workspace.tsx @@ -14,11 +14,11 @@ const Workspace = () => { ]) useEffect(() => { - currentModel() - .then((res) => res.json()) - .then((model) => { - updateSettings({ model }) - }) + const fetchCurrentModel = async () => { + const model = await currentModel() + updateSettings({ model }) + } + fetchCurrentModel() }, []) return ( diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 3fb20f5..d3a6fa9 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -1,4 +1,11 @@ -import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types" +import { + Filename, + GenInfo, + ModelInfo, + PowerPaintTask, + Rect, + ServerConfig, +} from "@/lib/types" import { Settings } from "@/lib/states" import { convertToBase64, srcToFile } from "@/lib/utils" import axios from "axios" @@ -24,124 +31,113 @@ export default async function inpaint( const exampleImageBase64 = paintByExampleImage ? await convertToBase64(paintByExampleImage) : null - try { - const res = await fetch(`${API_ENDPOINT}/inpaint`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - image: imageBase64, - mask: maskBase64, - ldm_steps: settings.ldmSteps, - ldm_sampler: settings.ldmSampler, - zits_wireframe: settings.zitsWireframe, - cv2_flag: settings.cv2Flag, - cv2_radius: settings.cv2Radius, - hd_strategy: "Crop", - hd_strategy_crop_triger_size: 640, - hd_strategy_crop_margin: 128, - hd_trategy_resize_imit: 2048, - prompt: settings.prompt, - negative_prompt: settings.negativePrompt, - use_croper: settings.showCropper, - croper_x: croperRect.x, - croper_y: croperRect.y, - croper_height: croperRect.height, - croper_width: croperRect.width, - use_extender: settings.showExtender, - extender_x: extenderState.x, - extender_y: extenderState.y, - extender_height: extenderState.height, - extender_width: extenderState.width, - sd_mask_blur: settings.sdMaskBlur, - sd_strength: settings.sdStrength, - sd_steps: settings.sdSteps, - sd_guidance_scale: settings.sdGuidanceScale, - sd_sampler: settings.sdSampler, - sd_seed: settings.seedFixed ? settings.seed : -1, - sd_match_histograms: settings.sdMatchHistograms, - sd_freeu: settings.enableFreeu, - sd_freeu_config: settings.freeuConfig, - sd_lcm_lora: settings.enableLCMLora, - paint_by_example_example_image: exampleImageBase64, - p2p_image_guidance_scale: settings.p2pImageGuidanceScale, - enable_controlnet: settings.enableControlnet, - controlnet_conditioning_scale: settings.controlnetConditioningScale, - controlnet_method: settings.controlnetMethod - ? settings.controlnetMethod - : "", - powerpaint_task: settings.showExtender - ? PowerPaintTask.outpainting - : settings.powerpaintTask, - }), - }) + + const res = await fetch(`${API_ENDPOINT}/inpaint`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + image: imageBase64, + mask: maskBase64, + ldm_steps: settings.ldmSteps, + ldm_sampler: settings.ldmSampler, + zits_wireframe: settings.zitsWireframe, + cv2_flag: settings.cv2Flag, + cv2_radius: settings.cv2Radius, + hd_strategy: "Crop", + hd_strategy_crop_triger_size: 640, + hd_strategy_crop_margin: 128, + hd_trategy_resize_imit: 2048, + prompt: settings.prompt, + negative_prompt: settings.negativePrompt, + use_croper: settings.showCropper, + croper_x: croperRect.x, + croper_y: croperRect.y, + croper_height: croperRect.height, + croper_width: croperRect.width, + use_extender: settings.showExtender, + extender_x: extenderState.x, + extender_y: extenderState.y, + extender_height: extenderState.height, + extender_width: extenderState.width, + sd_mask_blur: settings.sdMaskBlur, + sd_strength: settings.sdStrength, + sd_steps: settings.sdSteps, + sd_guidance_scale: settings.sdGuidanceScale, + sd_sampler: settings.sdSampler, + sd_seed: settings.seedFixed ? settings.seed : -1, + sd_match_histograms: settings.sdMatchHistograms, + sd_freeu: settings.enableFreeu, + sd_freeu_config: settings.freeuConfig, + sd_lcm_lora: settings.enableLCMLora, + paint_by_example_example_image: exampleImageBase64, + p2p_image_guidance_scale: settings.p2pImageGuidanceScale, + enable_controlnet: settings.enableControlnet, + controlnet_conditioning_scale: settings.controlnetConditioningScale, + controlnet_method: settings.controlnetMethod + ? settings.controlnetMethod + : "", + powerpaint_task: settings.showExtender + ? PowerPaintTask.outpainting + : settings.powerpaintTask, + }), + }) + if (res.ok) { const blob = await res.blob() return { blob: URL.createObjectURL(blob), seed: res.headers.get("X-Seed"), } - } catch (error: any) { - throw new Error(`Something went wrong: ${JSON.stringify(error.message)}`) } + const errors = await res.json() + throw new Error(`Something went wrong: ${errors.errors}`) } -export function getServerConfig() { - return fetch(`${API_ENDPOINT}/server-config`, { - method: "GET", - }) +export async function getServerConfig(): Promise { + const res = await api.get(`/server-config`) + return res.data } -export function switchModel(name: string) { - return axios.post(`${API_ENDPOINT}/model`, { name }) +export async function switchModel(name: string): Promise { + const res = await api.post(`/model`, { name }) + return res.data } -export function currentModel() { - return fetch(`${API_ENDPOINT}/model`, { - method: "GET", - }) +export async function currentModel(): Promise { + const res = await api.get("/model") + return res.data } export function fetchModelInfos(): Promise { return api.get("/models").then((response) => response.data) } -export function modelDownloaded(name: string) { - return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, { - method: "GET", - }) -} - export async function runPlugin( name: string, imageFile: File, upscale?: number, clicks?: number[][] ) { - const fd = new FormData() - fd.append("name", name) - fd.append("image", imageFile) - if (upscale) { - fd.append("upscale", upscale.toString()) - } - if (clicks) { - fd.append("clicks", JSON.stringify(clicks)) - } - - try { - const res = await fetch(`${API_ENDPOINT}/run_plugin`, { - method: "POST", - body: fd, - }) - if (res.ok) { - const blob = await res.blob() - return { blob: URL.createObjectURL(blob) } - } - const errMsg = await res.text() - throw new Error(errMsg) - } catch (error) { - throw new Error(`Something went wrong: ${error}`) + const imageBase64 = await convertToBase64(imageFile) + const res = await fetch(`${API_ENDPOINT}/run_plugin`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + name, + image: imageBase64, + upscale, + clicks, + }), + }) + if (res.ok) { + const blob = await res.blob() + return { blob: URL.createObjectURL(blob) } } + const errMsg = await res.json() + throw new Error(errMsg) } export async function getMediaFile(tab: string, filename: string) { @@ -160,12 +156,12 @@ export async function getMediaFile(tab: string, filename: string) { }) return file } - const errMsg = await res.text() - throw new Error(errMsg) + const errMsg = await res.json() + throw new Error(errMsg.errors) } export async function getMedias(tab: string): Promise { - const res = await axios.get(`${API_ENDPOINT}/medias`, { params: { tab } }) + const res = await api.get(`medias`, { params: { tab } }) return res.data } @@ -191,3 +187,10 @@ export async function downloadToOutput( throw new Error(`Something went wrong: ${error}`) } } + +export async function getGenInfo(file: File): Promise { + const fd = new FormData() + fd.append("file", file) + const res = await api.post(`/gen-info`, fd) + return res.data +} diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index 5cfbf8e..7ed45c7 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -15,6 +15,7 @@ import { Point, PowerPaintTask, SDSampler, + ServerConfig, Size, SortBy, SortOrder, @@ -33,7 +34,7 @@ import { loadImage, srcToFile, } from "./utils" -import inpaint, { runPlugin } from "./api" +import inpaint, { getGenInfo, runPlugin } from "./api" import { toast } from "@/components/ui/use-toast" type FileManagerState = { @@ -57,6 +58,7 @@ export type Settings = { enableDownloadMask: boolean enableManualInpainting: boolean enableUploadMask: boolean + enableAutoExtractPrompt: boolean showCropper: boolean showExtender: boolean extenderDirection: ExtenderDirection @@ -103,16 +105,6 @@ export type Settings = { powerpaintTask: PowerPaintTask } -type ServerConfig = { - plugins: string[] - enableFileManager: boolean - enableAutoSaving: boolean - enableControlnet: boolean - controlnetMethod: string - disableModelSwitch: boolean - isDesktop: boolean -} - type InteractiveSegState = { isInteractiveSeg: boolean interactiveSegMask: HTMLImageElement | null @@ -162,7 +154,7 @@ type AppState = { type AppAction = { updateAppState: (newState: Partial) => void - setFile: (file: File) => void + setFile: (file: File) => Promise setCustomFile: (file: File) => void setIsInpainting: (newValue: boolean) => void setIsPluginRunning: (newValue: boolean) => void @@ -304,6 +296,7 @@ const defaultValues: AppState = { enableDownloadMask: false, enableManualInpainting: false, enableUploadMask: false, + enableAutoExtractPrompt: true, ldmSteps: 30, ldmSampler: LDMSampler.ddim, zitsWireframe: true, @@ -540,9 +533,6 @@ export const useStore = createWithEqualityFn()( const start = new Date() const targetFile = await get().getCurrentTargetFile() const res = await runPlugin(pluginName, targetFile, params.upscale) - if (!res) { - throw new Error("Something went wrong on server side.") - } const { blob } = res const newRender = new Image() await loadImage(newRender, blob) @@ -818,7 +808,27 @@ export const useStore = createWithEqualityFn()( state.isPluginRunning = newValue }), - setFile: (file: File) => { + setFile: async (file: File) => { + if (get().settings.enableAutoExtractPrompt) { + try { + const res = await getGenInfo(file) + if (res.prompt) { + set((state) => { + state.settings.prompt = res.prompt + }) + } + if (res.negative_prompt) { + set((state) => { + state.settings.negativePrompt = res.negative_prompt + }) + } + } catch (e: any) { + toast({ + variant: "destructive", + description: e.message ? e.message : e.toString(), + }) + } + } set((state) => { state.file = file state.interactiveSegState = castDraft( diff --git a/web_app/src/lib/types.ts b/web_app/src/lib/types.ts index 9755d78..8b7f769 100644 --- a/web_app/src/lib/types.ts +++ b/web_app/src/lib/types.ts @@ -6,6 +6,21 @@ export interface Filename { mtime: number } +export interface ServerConfig { + plugins: string[] + enableFileManager: boolean + enableAutoSaving: boolean + enableControlnet: boolean + controlnetMethod: string + disableModelSwitch: boolean + isDesktop: boolean +} + +export interface GenInfo { + prompt: string + negative_prompt: string +} + export interface ModelInfo { name: string path: string