update plugins

This commit is contained in:
Qing 2024-01-02 11:07:35 +08:00
parent b0e009f879
commit a2fd5bb3ea
19 changed files with 337 additions and 227 deletions

View File

@ -210,26 +210,26 @@ class Api:
)
def api_run_plugin(self, req: RunPluginRequest):
ext = "png"
if req.name not in self.plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
image, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_res = self.plugins[req.name].run(image, req)
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_np_img = self.plugins[req.name](rgb_np_img, req)
torch_gc()
if req.name == InteractiveSeg.name:
return Response(
content=numpy_to_bytes(bgr_res, "png"),
media_type="image/png",
content=numpy_to_bytes(bgr_np_img, ext),
media_type=f"image/{ext}",
)
ext = "png"
if req.name in [RemoveBG.name, AnimeSeg.name]:
rgb_res = bgr_res
if bgr_np_img.shape[2] == 4:
rgba_np_img = bgr_np_img
else:
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
rgb_res = concat_alpha_channel(rgb_res, alpha_channel)
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
return Response(
content=pil_to_bytes(
Image.fromarray(rgb_res),
Image.fromarray(rgba_np_img),
ext=ext,
quality=self.config.quality,
infos=infos,

View File

@ -7,6 +7,7 @@ from PIL import Image
from lama_cleaner.helper import load_model
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.schema import RunPluginRequest
class REBNCONV(nn.Module):
@ -425,7 +426,7 @@ class AnimeSeg(BasePlugin):
ANIME_SEG_MODELS["md5"],
)
def __call__(self, rgb_np_img, files, form):
def __call__(self, rgb_np_img, req: RunPluginRequest):
return self.forward(rgb_np_img)
@torch.no_grad()

View File

@ -1,4 +1,7 @@
from loguru import logger
import numpy as np
from lama_cleaner.schema import RunPluginRequest
class BasePlugin:
@ -8,7 +11,8 @@ class BasePlugin:
logger.error(err_msg)
exit(-1)
def __call__(self, rgb_np_img, files, form):
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array:
# return RGBA np image or BGR np image
...
def check_dep(self):

View File

@ -3,6 +3,7 @@ from loguru import logger
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.schema import RunPluginRequest
class GFPGANPlugin(BasePlugin):
@ -36,7 +37,7 @@ class GFPGANPlugin(BasePlugin):
self.face_enhancer.face_helper.face_det.to(device)
)
def __call__(self, rgb_np_img, files, form):
def __call__(self, rgb_np_img, req: RunPluginRequest):
weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")

View File

@ -1,4 +1,6 @@
import hashlib
import json
from typing import List
import cv2
import numpy as np
@ -7,6 +9,7 @@ from loguru import logger
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry
from lama_cleaner.schema import RunPluginRequest
# 从小到大
SEGMENT_ANYTHING_MODELS = {
@ -44,11 +47,11 @@ class InteractiveSeg(BasePlugin):
)
self.prev_img_md5 = None
def __call__(self, rgb_np_img, files, form):
clicks = json.loads(form["clicks"])
return self.forward(rgb_np_img, clicks, form["img_md5"])
def __call__(self, rgb_np_img, req: RunPluginRequest):
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
return self.forward(rgb_np_img, req.clicks, img_md5)
def forward(self, rgb_np_img, clicks, img_md5):
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
input_point = []
input_label = []
for click in clicks:

View File

@ -6,6 +6,7 @@ from loguru import logger
from lama_cleaner.const import RealESRGANModel
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.schema import RunPluginRequest
class RealESRGANUpscaler(BasePlugin):
@ -76,11 +77,10 @@ class RealESRGANUpscaler(BasePlugin):
device=device,
)
def __call__(self, rgb_np_img, files, form):
def __call__(self, rgb_np_img, req: RunPluginRequest):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
scale = float(form["upscale"])
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
result = self.forward(bgr_np_img, scale)
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
result = self.forward(bgr_np_img, req.scale)
logger.info(f"RealESRGAN output shape: {result.shape}")
return result

View File

@ -4,6 +4,7 @@ import numpy as np
from torch.hub import get_dir
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.schema import RunPluginRequest
class RemoveBG(BasePlugin):
@ -19,7 +20,7 @@ class RemoveBG(BasePlugin):
self.session = new_session(model_name="u2net")
def __call__(self, rgb_np_img, files, form):
def __call__(self, rgb_np_img, req: RunPluginRequest):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
return self.forward(bgr_np_img)

View File

@ -3,6 +3,7 @@ from loguru import logger
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.schema import RunPluginRequest
class RestoreFormerPlugin(BasePlugin):
@ -31,7 +32,7 @@ class RestoreFormerPlugin(BasePlugin):
bg_upsampler=upscaler.model if upscaler is not None else None,
)
def __call__(self, rgb_np_img, files, form):
def __call__(self, rgb_np_img, req: RunPluginRequest):
weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")

View File

@ -136,6 +136,12 @@ class InpaintRequest(BaseModel):
extender_height: int = Field(640, description="Extend height for extender")
extender_width: int = Field(640, description="Extend width for extender")
sd_scale: float = Field(
1.0,
description="Resize the image before doing sd inpainting, the area outside the mask will not lose quality.",
gt=0.0,
le=1.0,
)
sd_mask_blur: int = Field(
33,
description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
@ -143,6 +149,7 @@ class InpaintRequest(BaseModel):
sd_strength: float = Field(
1.0,
description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image",
le=1.0,
)
sd_steps: int = Field(
50,
@ -202,7 +209,9 @@ class InpaintRequest(BaseModel):
# ControlNet
enable_controlnet: bool = Field(False, description="Enable controlnet")
controlnet_conditioning_scale: float = Field(0.4, description="Conditioning scale")
controlnet_conditioning_scale: float = Field(
0.4, description="Conditioning scale", gt=0.0, le=1.0
)
controlnet_method: str = Field(
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
)
@ -214,6 +223,8 @@ class InpaintRequest(BaseModel):
fitting_degree: float = Field(
1.0,
description="Control the fitting degree of the generated objects to the mask shape.",
gt=0.0,
le=1.0,
)
@field_validator("sd_seed")
@ -226,7 +237,7 @@ class InpaintRequest(BaseModel):
class RunPluginRequest(BaseModel):
name: str
image: Optional[str] = Field(..., description="base64 encoded image")
image: str = Field(..., description="base64 encoded image")
clicks: List[List[int]] = Field(
[], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]"
)

View File

@ -1,8 +1,11 @@
import hashlib
import os
import time
from PIL import Image
from lama_cleaner.helper import encode_pil_to_base64
from lama_cleaner.plugins.anime_seg import AnimeSeg
from lama_cleaner.schema import RunPluginRequest
from lama_cleaner.tests.utils import check_device, current_dir, save_dir
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -22,6 +25,8 @@ img_p = current_dir / "bunny.jpeg"
img_bytes = open(img_p, "rb").read()
bgr_img = cv2.imread(str(img_p))
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})
def _save(img, name):
@ -30,15 +35,18 @@ def _save(img, name):
def test_remove_bg():
model = RemoveBG()
res = model.forward(bgr_img)
res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
rgba_np_img = model(
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
)
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
_save(res, "test_remove_bg.png")
def test_anime_seg():
model = AnimeSeg()
img = cv2.imread(str(current_dir / "anime_test.png"))
res = model.forward(img)
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
assert len(res.shape) == 3
assert res.shape[-1] == 4
_save(res, "test_anime_seg.png")
@ -48,10 +56,16 @@ def test_anime_seg():
def test_upscale(device):
check_device(device)
model = RealESRGANUpscaler("realesr-general-x4v3", device)
res = model.forward(bgr_img, 2)
res = model(
rgb_img,
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
)
_save(res, f"test_upscale_x2_{device}.png")
res = model.forward(bgr_img, 4)
res = model(
rgb_img,
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
)
_save(res, f"test_upscale_x4_{device}.png")
@ -59,7 +73,7 @@ def test_upscale(device):
def test_gfpgan(device):
check_device(device)
model = GFPGANPlugin(device)
res = model(rgb_img, None, None)
res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64))
_save(res, f"test_gfpgan_{device}.png")
@ -67,20 +81,24 @@ def test_gfpgan(device):
def test_restoreformer(device):
check_device(device)
model = RestoreFormerPlugin(device)
res = model(rgb_img, None, None)
res = model(
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
)
_save(res, f"test_restoreformer_{device}.png")
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_segment_anything(device):
check_device(device)
img_md5 = hashlib.md5(img_bytes).hexdigest()
model = InteractiveSeg("vit_l", device)
new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
new_mask = model(
rgb_img,
RunPluginRequest(
name=InteractiveSeg.name,
image=rgb_img_base64,
clicks=([[448 // 2, 394 // 2, 1]]),
),
)
save_name = f"test_segment_anything_{device}.png"
_save(new_mask, save_name)
start = time.time()
model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
print(f"Time for {save_name}: {time.time() - start:.2f}s")

View File

@ -42,7 +42,7 @@ function Home() {
useEffect(() => {
const fetchServerConfig = async () => {
const serverConfig = await getServerConfig().then((res) => res.json())
const serverConfig = await getServerConfig()
setServerConfig(serverConfig)
if (serverConfig.isDesktop) {
// Keeping GUI Window Open

View File

@ -363,9 +363,6 @@ export default function Editor(props: EditorProps) {
undefined,
newClicks
)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob } = res
const img = new Image()
img.onload = () => {

View File

@ -78,6 +78,7 @@ export default function FileManager(props: Props) {
const ref = useRef(null)
const debouncedSearchText = useDebounce(fileManagerState.searchText, 300)
const [tab, setTab] = useState(IMAGE_TAB)
const [filenames, setFilenames] = useState<Filename[]>([])
const [photos, setPhotos] = useState<Photo[]>([])
const [photoIndex, setPhotoIndex] = useState(0)
@ -131,13 +132,28 @@ export default function FileManager(props: Props) {
[open, closeScrollTop]
)
useEffect(() => {
const fetchData = async () => {
try {
const filenames = await getMedias(tab)
setFilenames(filenames)
} catch (e: any) {
toast({
variant: "destructive",
title: "Uh oh! Something went wrong.",
description: e.message ? e.message : e.toString(),
})
}
}
fetchData()
}, [tab])
useEffect(() => {
if (!open) {
return
}
const fetchData = async () => {
try {
const filenames = await getMedias(tab)
let filteredFilenames = filenames
if (debouncedSearchText) {
const fuse = new Fuse(filteredFilenames, {
@ -173,7 +189,7 @@ export default function FileManager(props: Props) {
}
}
fetchData()
}, [tab, debouncedSearchText, fileManagerState, photoWidth, open])
}, [filenames, debouncedSearchText, fileManagerState, photoWidth, open])
const onScroll = (event: SyntheticEvent) => {
setScrollTop(event.currentTarget.scrollTop)

View File

@ -99,7 +99,7 @@ export function SettingsDialog() {
},
})
function onSubmit(values: z.infer<typeof formSchema>) {
async function onSubmit(values: z.infer<typeof formSchema>) {
// Do something with the form values. ✅ This will be type-safe and validated.
updateSettings({
enableDownloadMask: values.enableDownloadMask,
@ -116,24 +116,22 @@ export function SettingsDialog() {
if (model.name !== settings.model.name) {
toggleOpenModelSwitching()
updateAppState({ disableShortCuts: true })
switchModel(model.name)
.then((res) => {
toast({
title: `Switch to ${model.name} success`,
})
setAppModel(model)
try {
const newModel = await switchModel(model.name)
toast({
title: `Switch to ${newModel.name} success`,
})
.catch((error: any) => {
toast({
variant: "destructive",
title: `Switch to ${model.name} failed: ${error}`,
})
setModel(settings.model)
})
.finally(() => {
toggleOpenModelSwitching()
updateAppState({ disableShortCuts: false })
setAppModel(model)
} catch (error: any) {
toast({
variant: "destructive",
title: `Switch to ${model.name} failed: ${error}`,
})
setModel(settings.model)
} finally {
toggleOpenModelSwitching()
updateAppState({ disableShortCuts: false })
}
}
}

View File

@ -69,6 +69,27 @@ const DiffusionOptions = () => {
}
}
const renderCropper = () => {
return (
<RowContainer>
<LabelTitle
text="Cropper"
toolTip="Inpainting on part of image, improve inference speed and reduce memory usage."
/>
<Switch
id="cropper"
checked={settings.showCropper}
onCheckedChange={(value) => {
updateSettings({ showCropper: value })
if (value) {
updateSettings({ showExtender: false })
}
}}
/>
</RowContainer>
)
}
const renderConterNetSetting = () => {
if (!settings.model.support_controlnet) {
return null
@ -558,28 +579,8 @@ const DiffusionOptions = () => {
)
}
return (
<div className="flex flex-col gap-4 mt-4">
<RowContainer>
<LabelTitle
text="Cropper"
toolTip="Inpainting on part of image, improve inference speed and reduce memory usage."
/>
<Switch
id="cropper"
checked={settings.showCropper}
onCheckedChange={(value) => {
updateSettings({ showCropper: value })
if (value) {
updateSettings({ showExtender: false })
}
}}
/>
</RowContainer>
{renderExtender()}
{renderPowerPaintTaskType()}
const renderSteps = () => {
return (
<div className="flex flex-col gap-1">
<LabelTitle
htmlFor="steps"
@ -607,7 +608,11 @@ const DiffusionOptions = () => {
/>
</RowContainer>
</div>
)
}
const renderGuidanceScale = () => {
return (
<div className="flex flex-col gap-1">
<LabelTitle
text="Guidance scale"
@ -637,10 +642,11 @@ const DiffusionOptions = () => {
/>
</RowContainer>
</div>
)
}
{renderP2PImageGuidanceScale()}
{renderStrength()}
const renderSampler = () => {
return (
<RowContainer>
<LabelTitle text="Sampler" />
<Select
@ -664,7 +670,11 @@ const DiffusionOptions = () => {
</SelectContent>
</Select>
</RowContainer>
)
}
const renderSeed = () => {
return (
<RowContainer>
{/* 每次会从服务器返回更新该值 */}
<LabelTitle
@ -692,15 +702,11 @@ const DiffusionOptions = () => {
/>
</div>
</RowContainer>
)
}
{renderNegativePrompt()}
<Separator />
{renderConterNetSetting()}
{renderFreeu()}
{renderLCMLora()}
const renderMaskBlur = () => {
return (
<div className="flex flex-col gap-1">
<LabelTitle
text="Mask blur"
@ -727,24 +733,49 @@ const DiffusionOptions = () => {
/>
</RowContainer>
</div>
)
}
<RowContainer>
<LabelTitle
text="Match histograms"
toolTip="Match the inpainting result histogram to the source image histogram"
url="https://github.com/Sanster/lama-cleaner/pull/143#issuecomment-1325859307"
/>
<Switch
id="match-histograms"
checked={settings.sdMatchHistograms}
onCheckedChange={(value) => {
updateSettings({ sdMatchHistograms: value })
}}
/>
</RowContainer>
const renderMatchHistograms = () => {
return (
<>
<RowContainer>
<LabelTitle
text="Match histograms"
toolTip="Match the inpainting result histogram to the source image histogram"
url="https://github.com/Sanster/lama-cleaner/pull/143#issuecomment-1325859307"
/>
<Switch
id="match-histograms"
checked={settings.sdMatchHistograms}
onCheckedChange={(value) => {
updateSettings({ sdMatchHistograms: value })
}}
/>
</RowContainer>
<Separator />
</>
)
}
return (
<div className="flex flex-col gap-4 mt-4">
{renderCropper()}
{renderExtender()}
{renderPowerPaintTaskType()}
{renderSteps()}
{renderGuidanceScale()}
{renderP2PImageGuidanceScale()}
{renderStrength()}
{renderSampler()}
{renderSeed()}
{renderNegativePrompt()}
<Separator />
{renderConterNetSetting()}
{renderLCMLora()}
{renderMaskBlur()}
{renderMatchHistograms()}
{renderFreeu()}
{renderPaintByExample()}
</div>
)

View File

@ -14,11 +14,11 @@ const Workspace = () => {
])
useEffect(() => {
currentModel()
.then((res) => res.json())
.then((model) => {
updateSettings({ model })
})
const fetchCurrentModel = async () => {
const model = await currentModel()
updateSettings({ model })
}
fetchCurrentModel()
}, [])
return (

View File

@ -1,4 +1,11 @@
import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
import {
Filename,
GenInfo,
ModelInfo,
PowerPaintTask,
Rect,
ServerConfig,
} from "@/lib/types"
import { Settings } from "@/lib/states"
import { convertToBase64, srcToFile } from "@/lib/utils"
import axios from "axios"
@ -24,124 +31,113 @@ export default async function inpaint(
const exampleImageBase64 = paintByExampleImage
? await convertToBase64(paintByExampleImage)
: null
try {
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
image: imageBase64,
mask: maskBase64,
ldm_steps: settings.ldmSteps,
ldm_sampler: settings.ldmSampler,
zits_wireframe: settings.zitsWireframe,
cv2_flag: settings.cv2Flag,
cv2_radius: settings.cv2Radius,
hd_strategy: "Crop",
hd_strategy_crop_triger_size: 640,
hd_strategy_crop_margin: 128,
hd_trategy_resize_imit: 2048,
prompt: settings.prompt,
negative_prompt: settings.negativePrompt,
use_croper: settings.showCropper,
croper_x: croperRect.x,
croper_y: croperRect.y,
croper_height: croperRect.height,
croper_width: croperRect.width,
use_extender: settings.showExtender,
extender_x: extenderState.x,
extender_y: extenderState.y,
extender_height: extenderState.height,
extender_width: extenderState.width,
sd_mask_blur: settings.sdMaskBlur,
sd_strength: settings.sdStrength,
sd_steps: settings.sdSteps,
sd_guidance_scale: settings.sdGuidanceScale,
sd_sampler: settings.sdSampler,
sd_seed: settings.seedFixed ? settings.seed : -1,
sd_match_histograms: settings.sdMatchHistograms,
sd_freeu: settings.enableFreeu,
sd_freeu_config: settings.freeuConfig,
sd_lcm_lora: settings.enableLCMLora,
paint_by_example_example_image: exampleImageBase64,
p2p_image_guidance_scale: settings.p2pImageGuidanceScale,
enable_controlnet: settings.enableControlnet,
controlnet_conditioning_scale: settings.controlnetConditioningScale,
controlnet_method: settings.controlnetMethod
? settings.controlnetMethod
: "",
powerpaint_task: settings.showExtender
? PowerPaintTask.outpainting
: settings.powerpaintTask,
}),
})
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
image: imageBase64,
mask: maskBase64,
ldm_steps: settings.ldmSteps,
ldm_sampler: settings.ldmSampler,
zits_wireframe: settings.zitsWireframe,
cv2_flag: settings.cv2Flag,
cv2_radius: settings.cv2Radius,
hd_strategy: "Crop",
hd_strategy_crop_triger_size: 640,
hd_strategy_crop_margin: 128,
hd_trategy_resize_imit: 2048,
prompt: settings.prompt,
negative_prompt: settings.negativePrompt,
use_croper: settings.showCropper,
croper_x: croperRect.x,
croper_y: croperRect.y,
croper_height: croperRect.height,
croper_width: croperRect.width,
use_extender: settings.showExtender,
extender_x: extenderState.x,
extender_y: extenderState.y,
extender_height: extenderState.height,
extender_width: extenderState.width,
sd_mask_blur: settings.sdMaskBlur,
sd_strength: settings.sdStrength,
sd_steps: settings.sdSteps,
sd_guidance_scale: settings.sdGuidanceScale,
sd_sampler: settings.sdSampler,
sd_seed: settings.seedFixed ? settings.seed : -1,
sd_match_histograms: settings.sdMatchHistograms,
sd_freeu: settings.enableFreeu,
sd_freeu_config: settings.freeuConfig,
sd_lcm_lora: settings.enableLCMLora,
paint_by_example_example_image: exampleImageBase64,
p2p_image_guidance_scale: settings.p2pImageGuidanceScale,
enable_controlnet: settings.enableControlnet,
controlnet_conditioning_scale: settings.controlnetConditioningScale,
controlnet_method: settings.controlnetMethod
? settings.controlnetMethod
: "",
powerpaint_task: settings.showExtender
? PowerPaintTask.outpainting
: settings.powerpaintTask,
}),
})
if (res.ok) {
const blob = await res.blob()
return {
blob: URL.createObjectURL(blob),
seed: res.headers.get("X-Seed"),
}
} catch (error: any) {
throw new Error(`Something went wrong: ${JSON.stringify(error.message)}`)
}
const errors = await res.json()
throw new Error(`Something went wrong: ${errors.errors}`)
}
export function getServerConfig() {
return fetch(`${API_ENDPOINT}/server-config`, {
method: "GET",
})
export async function getServerConfig(): Promise<ServerConfig> {
const res = await api.get(`/server-config`)
return res.data
}
export function switchModel(name: string) {
return axios.post(`${API_ENDPOINT}/model`, { name })
export async function switchModel(name: string): Promise<ModelInfo> {
const res = await api.post(`/model`, { name })
return res.data
}
export function currentModel() {
return fetch(`${API_ENDPOINT}/model`, {
method: "GET",
})
export async function currentModel(): Promise<ModelInfo> {
const res = await api.get("/model")
return res.data
}
export function fetchModelInfos(): Promise<ModelInfo[]> {
return api.get("/models").then((response) => response.data)
}
export function modelDownloaded(name: string) {
return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, {
method: "GET",
})
}
export async function runPlugin(
name: string,
imageFile: File,
upscale?: number,
clicks?: number[][]
) {
const fd = new FormData()
fd.append("name", name)
fd.append("image", imageFile)
if (upscale) {
fd.append("upscale", upscale.toString())
}
if (clicks) {
fd.append("clicks", JSON.stringify(clicks))
}
try {
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
method: "POST",
body: fd,
})
if (res.ok) {
const blob = await res.blob()
return { blob: URL.createObjectURL(blob) }
}
const errMsg = await res.text()
throw new Error(errMsg)
} catch (error) {
throw new Error(`Something went wrong: ${error}`)
const imageBase64 = await convertToBase64(imageFile)
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
name,
image: imageBase64,
upscale,
clicks,
}),
})
if (res.ok) {
const blob = await res.blob()
return { blob: URL.createObjectURL(blob) }
}
const errMsg = await res.json()
throw new Error(errMsg)
}
export async function getMediaFile(tab: string, filename: string) {
@ -160,12 +156,12 @@ export async function getMediaFile(tab: string, filename: string) {
})
return file
}
const errMsg = await res.text()
throw new Error(errMsg)
const errMsg = await res.json()
throw new Error(errMsg.errors)
}
export async function getMedias(tab: string): Promise<Filename[]> {
const res = await axios.get(`${API_ENDPOINT}/medias`, { params: { tab } })
const res = await api.get(`medias`, { params: { tab } })
return res.data
}
@ -191,3 +187,10 @@ export async function downloadToOutput(
throw new Error(`Something went wrong: ${error}`)
}
}
export async function getGenInfo(file: File): Promise<GenInfo> {
const fd = new FormData()
fd.append("file", file)
const res = await api.post(`/gen-info`, fd)
return res.data
}

View File

@ -15,6 +15,7 @@ import {
Point,
PowerPaintTask,
SDSampler,
ServerConfig,
Size,
SortBy,
SortOrder,
@ -33,7 +34,7 @@ import {
loadImage,
srcToFile,
} from "./utils"
import inpaint, { runPlugin } from "./api"
import inpaint, { getGenInfo, runPlugin } from "./api"
import { toast } from "@/components/ui/use-toast"
type FileManagerState = {
@ -57,6 +58,7 @@ export type Settings = {
enableDownloadMask: boolean
enableManualInpainting: boolean
enableUploadMask: boolean
enableAutoExtractPrompt: boolean
showCropper: boolean
showExtender: boolean
extenderDirection: ExtenderDirection
@ -103,16 +105,6 @@ export type Settings = {
powerpaintTask: PowerPaintTask
}
type ServerConfig = {
plugins: string[]
enableFileManager: boolean
enableAutoSaving: boolean
enableControlnet: boolean
controlnetMethod: string
disableModelSwitch: boolean
isDesktop: boolean
}
type InteractiveSegState = {
isInteractiveSeg: boolean
interactiveSegMask: HTMLImageElement | null
@ -162,7 +154,7 @@ type AppState = {
type AppAction = {
updateAppState: (newState: Partial<AppState>) => void
setFile: (file: File) => void
setFile: (file: File) => Promise<void>
setCustomFile: (file: File) => void
setIsInpainting: (newValue: boolean) => void
setIsPluginRunning: (newValue: boolean) => void
@ -304,6 +296,7 @@ const defaultValues: AppState = {
enableDownloadMask: false,
enableManualInpainting: false,
enableUploadMask: false,
enableAutoExtractPrompt: true,
ldmSteps: 30,
ldmSampler: LDMSampler.ddim,
zitsWireframe: true,
@ -540,9 +533,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
const start = new Date()
const targetFile = await get().getCurrentTargetFile()
const res = await runPlugin(pluginName, targetFile, params.upscale)
if (!res) {
throw new Error("Something went wrong on server side.")
}
const { blob } = res
const newRender = new Image()
await loadImage(newRender, blob)
@ -818,7 +808,27 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
state.isPluginRunning = newValue
}),
setFile: (file: File) => {
setFile: async (file: File) => {
if (get().settings.enableAutoExtractPrompt) {
try {
const res = await getGenInfo(file)
if (res.prompt) {
set((state) => {
state.settings.prompt = res.prompt
})
}
if (res.negative_prompt) {
set((state) => {
state.settings.negativePrompt = res.negative_prompt
})
}
} catch (e: any) {
toast({
variant: "destructive",
description: e.message ? e.message : e.toString(),
})
}
}
set((state) => {
state.file = file
state.interactiveSegState = castDraft(

View File

@ -6,6 +6,21 @@ export interface Filename {
mtime: number
}
export interface ServerConfig {
plugins: string[]
enableFileManager: boolean
enableAutoSaving: boolean
enableControlnet: boolean
controlnetMethod: string
disableModelSwitch: boolean
isDesktop: boolean
}
export interface GenInfo {
prompt: string
negative_prompt: string
}
export interface ModelInfo {
name: string
path: string