update plugins
This commit is contained in:
parent
b0e009f879
commit
a2fd5bb3ea
@ -210,26 +210,26 @@ class Api:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def api_run_plugin(self, req: RunPluginRequest):
|
def api_run_plugin(self, req: RunPluginRequest):
|
||||||
|
ext = "png"
|
||||||
if req.name not in self.plugins:
|
if req.name not in self.plugins:
|
||||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||||
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||||
bgr_res = self.plugins[req.name].run(image, req)
|
bgr_np_img = self.plugins[req.name](rgb_np_img, req)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
if req.name == InteractiveSeg.name:
|
if req.name == InteractiveSeg.name:
|
||||||
return Response(
|
return Response(
|
||||||
content=numpy_to_bytes(bgr_res, "png"),
|
content=numpy_to_bytes(bgr_np_img, ext),
|
||||||
media_type="image/png",
|
media_type=f"image/{ext}",
|
||||||
)
|
)
|
||||||
ext = "png"
|
if bgr_np_img.shape[2] == 4:
|
||||||
if req.name in [RemoveBG.name, AnimeSeg.name]:
|
rgba_np_img = bgr_np_img
|
||||||
rgb_res = bgr_res
|
|
||||||
else:
|
else:
|
||||||
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
|
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
|
||||||
rgb_res = concat_alpha_channel(rgb_res, alpha_channel)
|
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
content=pil_to_bytes(
|
content=pil_to_bytes(
|
||||||
Image.fromarray(rgb_res),
|
Image.fromarray(rgba_np_img),
|
||||||
ext=ext,
|
ext=ext,
|
||||||
quality=self.config.quality,
|
quality=self.config.quality,
|
||||||
infos=infos,
|
infos=infos,
|
||||||
|
@ -7,6 +7,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from lama_cleaner.helper import load_model
|
from lama_cleaner.helper import load_model
|
||||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
|
|
||||||
|
|
||||||
class REBNCONV(nn.Module):
|
class REBNCONV(nn.Module):
|
||||||
@ -425,7 +426,7 @@ class AnimeSeg(BasePlugin):
|
|||||||
ANIME_SEG_MODELS["md5"],
|
ANIME_SEG_MODELS["md5"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||||
return self.forward(rgb_np_img)
|
return self.forward(rgb_np_img)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin:
|
||||||
@ -8,7 +11,8 @@ class BasePlugin:
|
|||||||
logger.error(err_msg)
|
logger.error(err_msg)
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array:
|
||||||
|
# return RGBA np image or BGR np image
|
||||||
...
|
...
|
||||||
|
|
||||||
def check_dep(self):
|
def check_dep(self):
|
||||||
|
@ -3,6 +3,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
|
|
||||||
|
|
||||||
class GFPGANPlugin(BasePlugin):
|
class GFPGANPlugin(BasePlugin):
|
||||||
@ -36,7 +37,7 @@ class GFPGANPlugin(BasePlugin):
|
|||||||
self.face_enhancer.face_helper.face_det.to(device)
|
self.face_enhancer.face_helper.face_det.to(device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||||
weight = 0.5
|
weight = 0.5
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
|
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -7,6 +9,7 @@ from loguru import logger
|
|||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry
|
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
|
|
||||||
# 从小到大
|
# 从小到大
|
||||||
SEGMENT_ANYTHING_MODELS = {
|
SEGMENT_ANYTHING_MODELS = {
|
||||||
@ -44,11 +47,11 @@ class InteractiveSeg(BasePlugin):
|
|||||||
)
|
)
|
||||||
self.prev_img_md5 = None
|
self.prev_img_md5 = None
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||||
clicks = json.loads(form["clicks"])
|
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||||
return self.forward(rgb_np_img, clicks, form["img_md5"])
|
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||||
|
|
||||||
def forward(self, rgb_np_img, clicks, img_md5):
|
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
||||||
input_point = []
|
input_point = []
|
||||||
input_label = []
|
input_label = []
|
||||||
for click in clicks:
|
for click in clicks:
|
||||||
|
@ -6,6 +6,7 @@ from loguru import logger
|
|||||||
from lama_cleaner.const import RealESRGANModel
|
from lama_cleaner.const import RealESRGANModel
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
|
|
||||||
|
|
||||||
class RealESRGANUpscaler(BasePlugin):
|
class RealESRGANUpscaler(BasePlugin):
|
||||||
@ -76,11 +77,10 @@ class RealESRGANUpscaler(BasePlugin):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
scale = float(form["upscale"])
|
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
|
result = self.forward(bgr_np_img, req.scale)
|
||||||
result = self.forward(bgr_np_img, scale)
|
|
||||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import numpy as np
|
|||||||
from torch.hub import get_dir
|
from torch.hub import get_dir
|
||||||
|
|
||||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
|
|
||||||
|
|
||||||
class RemoveBG(BasePlugin):
|
class RemoveBG(BasePlugin):
|
||||||
@ -19,7 +20,7 @@ class RemoveBG(BasePlugin):
|
|||||||
|
|
||||||
self.session = new_session(model_name="u2net")
|
self.session = new_session(model_name="u2net")
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
return self.forward(bgr_np_img)
|
return self.forward(bgr_np_img)
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
|
|
||||||
|
|
||||||
class RestoreFormerPlugin(BasePlugin):
|
class RestoreFormerPlugin(BasePlugin):
|
||||||
@ -31,7 +32,7 @@ class RestoreFormerPlugin(BasePlugin):
|
|||||||
bg_upsampler=upscaler.model if upscaler is not None else None,
|
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||||
weight = 0.5
|
weight = 0.5
|
||||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||||
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
||||||
|
@ -136,6 +136,12 @@ class InpaintRequest(BaseModel):
|
|||||||
extender_height: int = Field(640, description="Extend height for extender")
|
extender_height: int = Field(640, description="Extend height for extender")
|
||||||
extender_width: int = Field(640, description="Extend width for extender")
|
extender_width: int = Field(640, description="Extend width for extender")
|
||||||
|
|
||||||
|
sd_scale: float = Field(
|
||||||
|
1.0,
|
||||||
|
description="Resize the image before doing sd inpainting, the area outside the mask will not lose quality.",
|
||||||
|
gt=0.0,
|
||||||
|
le=1.0,
|
||||||
|
)
|
||||||
sd_mask_blur: int = Field(
|
sd_mask_blur: int = Field(
|
||||||
33,
|
33,
|
||||||
description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
|
description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
|
||||||
@ -143,6 +149,7 @@ class InpaintRequest(BaseModel):
|
|||||||
sd_strength: float = Field(
|
sd_strength: float = Field(
|
||||||
1.0,
|
1.0,
|
||||||
description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image",
|
description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image",
|
||||||
|
le=1.0,
|
||||||
)
|
)
|
||||||
sd_steps: int = Field(
|
sd_steps: int = Field(
|
||||||
50,
|
50,
|
||||||
@ -202,7 +209,9 @@ class InpaintRequest(BaseModel):
|
|||||||
|
|
||||||
# ControlNet
|
# ControlNet
|
||||||
enable_controlnet: bool = Field(False, description="Enable controlnet")
|
enable_controlnet: bool = Field(False, description="Enable controlnet")
|
||||||
controlnet_conditioning_scale: float = Field(0.4, description="Conditioning scale")
|
controlnet_conditioning_scale: float = Field(
|
||||||
|
0.4, description="Conditioning scale", gt=0.0, le=1.0
|
||||||
|
)
|
||||||
controlnet_method: str = Field(
|
controlnet_method: str = Field(
|
||||||
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
||||||
)
|
)
|
||||||
@ -214,6 +223,8 @@ class InpaintRequest(BaseModel):
|
|||||||
fitting_degree: float = Field(
|
fitting_degree: float = Field(
|
||||||
1.0,
|
1.0,
|
||||||
description="Control the fitting degree of the generated objects to the mask shape.",
|
description="Control the fitting degree of the generated objects to the mask shape.",
|
||||||
|
gt=0.0,
|
||||||
|
le=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("sd_seed")
|
@field_validator("sd_seed")
|
||||||
@ -226,7 +237,7 @@ class InpaintRequest(BaseModel):
|
|||||||
|
|
||||||
class RunPluginRequest(BaseModel):
|
class RunPluginRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
image: Optional[str] = Field(..., description="base64 encoded image")
|
image: str = Field(..., description="base64 encoded image")
|
||||||
clicks: List[List[int]] = Field(
|
clicks: List[List[int]] = Field(
|
||||||
[], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]"
|
[], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]"
|
||||||
)
|
)
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from lama_cleaner.helper import encode_pil_to_base64
|
||||||
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
||||||
|
from lama_cleaner.schema import RunPluginRequest
|
||||||
from lama_cleaner.tests.utils import check_device, current_dir, save_dir
|
from lama_cleaner.tests.utils import check_device, current_dir, save_dir
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
@ -22,6 +25,8 @@ img_p = current_dir / "bunny.jpeg"
|
|||||||
img_bytes = open(img_p, "rb").read()
|
img_bytes = open(img_p, "rb").read()
|
||||||
bgr_img = cv2.imread(str(img_p))
|
bgr_img = cv2.imread(str(img_p))
|
||||||
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
|
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
|
||||||
|
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
|
||||||
|
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})
|
||||||
|
|
||||||
|
|
||||||
def _save(img, name):
|
def _save(img, name):
|
||||||
@ -30,15 +35,18 @@ def _save(img, name):
|
|||||||
|
|
||||||
def test_remove_bg():
|
def test_remove_bg():
|
||||||
model = RemoveBG()
|
model = RemoveBG()
|
||||||
res = model.forward(bgr_img)
|
rgba_np_img = model(
|
||||||
res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
|
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||||
|
)
|
||||||
|
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
||||||
_save(res, "test_remove_bg.png")
|
_save(res, "test_remove_bg.png")
|
||||||
|
|
||||||
|
|
||||||
def test_anime_seg():
|
def test_anime_seg():
|
||||||
model = AnimeSeg()
|
model = AnimeSeg()
|
||||||
img = cv2.imread(str(current_dir / "anime_test.png"))
|
img = cv2.imread(str(current_dir / "anime_test.png"))
|
||||||
res = model.forward(img)
|
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
|
||||||
|
res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||||
assert len(res.shape) == 3
|
assert len(res.shape) == 3
|
||||||
assert res.shape[-1] == 4
|
assert res.shape[-1] == 4
|
||||||
_save(res, "test_anime_seg.png")
|
_save(res, "test_anime_seg.png")
|
||||||
@ -48,10 +56,16 @@ def test_anime_seg():
|
|||||||
def test_upscale(device):
|
def test_upscale(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
model = RealESRGANUpscaler("realesr-general-x4v3", device)
|
model = RealESRGANUpscaler("realesr-general-x4v3", device)
|
||||||
res = model.forward(bgr_img, 2)
|
res = model(
|
||||||
|
rgb_img,
|
||||||
|
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
|
||||||
|
)
|
||||||
_save(res, f"test_upscale_x2_{device}.png")
|
_save(res, f"test_upscale_x2_{device}.png")
|
||||||
|
|
||||||
res = model.forward(bgr_img, 4)
|
res = model(
|
||||||
|
rgb_img,
|
||||||
|
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
|
||||||
|
)
|
||||||
_save(res, f"test_upscale_x4_{device}.png")
|
_save(res, f"test_upscale_x4_{device}.png")
|
||||||
|
|
||||||
|
|
||||||
@ -59,7 +73,7 @@ def test_upscale(device):
|
|||||||
def test_gfpgan(device):
|
def test_gfpgan(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
model = GFPGANPlugin(device)
|
model = GFPGANPlugin(device)
|
||||||
res = model(rgb_img, None, None)
|
res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64))
|
||||||
_save(res, f"test_gfpgan_{device}.png")
|
_save(res, f"test_gfpgan_{device}.png")
|
||||||
|
|
||||||
|
|
||||||
@ -67,20 +81,24 @@ def test_gfpgan(device):
|
|||||||
def test_restoreformer(device):
|
def test_restoreformer(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
model = RestoreFormerPlugin(device)
|
model = RestoreFormerPlugin(device)
|
||||||
res = model(rgb_img, None, None)
|
res = model(
|
||||||
|
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
|
||||||
|
)
|
||||||
_save(res, f"test_restoreformer_{device}.png")
|
_save(res, f"test_restoreformer_{device}.png")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||||
def test_segment_anything(device):
|
def test_segment_anything(device):
|
||||||
check_device(device)
|
check_device(device)
|
||||||
img_md5 = hashlib.md5(img_bytes).hexdigest()
|
|
||||||
model = InteractiveSeg("vit_l", device)
|
model = InteractiveSeg("vit_l", device)
|
||||||
new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
|
new_mask = model(
|
||||||
|
rgb_img,
|
||||||
|
RunPluginRequest(
|
||||||
|
name=InteractiveSeg.name,
|
||||||
|
image=rgb_img_base64,
|
||||||
|
clicks=([[448 // 2, 394 // 2, 1]]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
save_name = f"test_segment_anything_{device}.png"
|
save_name = f"test_segment_anything_{device}.png"
|
||||||
_save(new_mask, save_name)
|
_save(new_mask, save_name)
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
|
|
||||||
print(f"Time for {save_name}: {time.time() - start:.2f}s")
|
|
||||||
|
@ -42,7 +42,7 @@ function Home() {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetchServerConfig = async () => {
|
const fetchServerConfig = async () => {
|
||||||
const serverConfig = await getServerConfig().then((res) => res.json())
|
const serverConfig = await getServerConfig()
|
||||||
setServerConfig(serverConfig)
|
setServerConfig(serverConfig)
|
||||||
if (serverConfig.isDesktop) {
|
if (serverConfig.isDesktop) {
|
||||||
// Keeping GUI Window Open
|
// Keeping GUI Window Open
|
||||||
|
@ -363,9 +363,6 @@ export default function Editor(props: EditorProps) {
|
|||||||
undefined,
|
undefined,
|
||||||
newClicks
|
newClicks
|
||||||
)
|
)
|
||||||
if (!res) {
|
|
||||||
throw new Error("Something went wrong on server side.")
|
|
||||||
}
|
|
||||||
const { blob } = res
|
const { blob } = res
|
||||||
const img = new Image()
|
const img = new Image()
|
||||||
img.onload = () => {
|
img.onload = () => {
|
||||||
|
@ -78,6 +78,7 @@ export default function FileManager(props: Props) {
|
|||||||
const ref = useRef(null)
|
const ref = useRef(null)
|
||||||
const debouncedSearchText = useDebounce(fileManagerState.searchText, 300)
|
const debouncedSearchText = useDebounce(fileManagerState.searchText, 300)
|
||||||
const [tab, setTab] = useState(IMAGE_TAB)
|
const [tab, setTab] = useState(IMAGE_TAB)
|
||||||
|
const [filenames, setFilenames] = useState<Filename[]>([])
|
||||||
const [photos, setPhotos] = useState<Photo[]>([])
|
const [photos, setPhotos] = useState<Photo[]>([])
|
||||||
const [photoIndex, setPhotoIndex] = useState(0)
|
const [photoIndex, setPhotoIndex] = useState(0)
|
||||||
|
|
||||||
@ -131,13 +132,28 @@ export default function FileManager(props: Props) {
|
|||||||
[open, closeScrollTop]
|
[open, closeScrollTop]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const fetchData = async () => {
|
||||||
|
try {
|
||||||
|
const filenames = await getMedias(tab)
|
||||||
|
setFilenames(filenames)
|
||||||
|
} catch (e: any) {
|
||||||
|
toast({
|
||||||
|
variant: "destructive",
|
||||||
|
title: "Uh oh! Something went wrong.",
|
||||||
|
description: e.message ? e.message : e.toString(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fetchData()
|
||||||
|
}, [tab])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!open) {
|
if (!open) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
const fetchData = async () => {
|
const fetchData = async () => {
|
||||||
try {
|
try {
|
||||||
const filenames = await getMedias(tab)
|
|
||||||
let filteredFilenames = filenames
|
let filteredFilenames = filenames
|
||||||
if (debouncedSearchText) {
|
if (debouncedSearchText) {
|
||||||
const fuse = new Fuse(filteredFilenames, {
|
const fuse = new Fuse(filteredFilenames, {
|
||||||
@ -173,7 +189,7 @@ export default function FileManager(props: Props) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
fetchData()
|
fetchData()
|
||||||
}, [tab, debouncedSearchText, fileManagerState, photoWidth, open])
|
}, [filenames, debouncedSearchText, fileManagerState, photoWidth, open])
|
||||||
|
|
||||||
const onScroll = (event: SyntheticEvent) => {
|
const onScroll = (event: SyntheticEvent) => {
|
||||||
setScrollTop(event.currentTarget.scrollTop)
|
setScrollTop(event.currentTarget.scrollTop)
|
||||||
|
@ -99,7 +99,7 @@ export function SettingsDialog() {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
function onSubmit(values: z.infer<typeof formSchema>) {
|
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||||
// Do something with the form values. ✅ This will be type-safe and validated.
|
// Do something with the form values. ✅ This will be type-safe and validated.
|
||||||
updateSettings({
|
updateSettings({
|
||||||
enableDownloadMask: values.enableDownloadMask,
|
enableDownloadMask: values.enableDownloadMask,
|
||||||
@ -116,24 +116,22 @@ export function SettingsDialog() {
|
|||||||
if (model.name !== settings.model.name) {
|
if (model.name !== settings.model.name) {
|
||||||
toggleOpenModelSwitching()
|
toggleOpenModelSwitching()
|
||||||
updateAppState({ disableShortCuts: true })
|
updateAppState({ disableShortCuts: true })
|
||||||
switchModel(model.name)
|
try {
|
||||||
.then((res) => {
|
const newModel = await switchModel(model.name)
|
||||||
toast({
|
toast({
|
||||||
title: `Switch to ${model.name} success`,
|
title: `Switch to ${newModel.name} success`,
|
||||||
})
|
})
|
||||||
setAppModel(model)
|
setAppModel(model)
|
||||||
})
|
} catch (error: any) {
|
||||||
.catch((error: any) => {
|
|
||||||
toast({
|
toast({
|
||||||
variant: "destructive",
|
variant: "destructive",
|
||||||
title: `Switch to ${model.name} failed: ${error}`,
|
title: `Switch to ${model.name} failed: ${error}`,
|
||||||
})
|
})
|
||||||
setModel(settings.model)
|
setModel(settings.model)
|
||||||
})
|
} finally {
|
||||||
.finally(() => {
|
|
||||||
toggleOpenModelSwitching()
|
toggleOpenModelSwitching()
|
||||||
updateAppState({ disableShortCuts: false })
|
updateAppState({ disableShortCuts: false })
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,6 +69,27 @@ const DiffusionOptions = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderCropper = () => {
|
||||||
|
return (
|
||||||
|
<RowContainer>
|
||||||
|
<LabelTitle
|
||||||
|
text="Cropper"
|
||||||
|
toolTip="Inpainting on part of image, improve inference speed and reduce memory usage."
|
||||||
|
/>
|
||||||
|
<Switch
|
||||||
|
id="cropper"
|
||||||
|
checked={settings.showCropper}
|
||||||
|
onCheckedChange={(value) => {
|
||||||
|
updateSettings({ showCropper: value })
|
||||||
|
if (value) {
|
||||||
|
updateSettings({ showExtender: false })
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</RowContainer>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const renderConterNetSetting = () => {
|
const renderConterNetSetting = () => {
|
||||||
if (!settings.model.support_controlnet) {
|
if (!settings.model.support_controlnet) {
|
||||||
return null
|
return null
|
||||||
@ -558,28 +579,8 @@ const DiffusionOptions = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderSteps = () => {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-4 mt-4">
|
|
||||||
<RowContainer>
|
|
||||||
<LabelTitle
|
|
||||||
text="Cropper"
|
|
||||||
toolTip="Inpainting on part of image, improve inference speed and reduce memory usage."
|
|
||||||
/>
|
|
||||||
<Switch
|
|
||||||
id="cropper"
|
|
||||||
checked={settings.showCropper}
|
|
||||||
onCheckedChange={(value) => {
|
|
||||||
updateSettings({ showCropper: value })
|
|
||||||
if (value) {
|
|
||||||
updateSettings({ showExtender: false })
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</RowContainer>
|
|
||||||
|
|
||||||
{renderExtender()}
|
|
||||||
{renderPowerPaintTaskType()}
|
|
||||||
|
|
||||||
<div className="flex flex-col gap-1">
|
<div className="flex flex-col gap-1">
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
htmlFor="steps"
|
htmlFor="steps"
|
||||||
@ -607,7 +608,11 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderGuidanceScale = () => {
|
||||||
|
return (
|
||||||
<div className="flex flex-col gap-1">
|
<div className="flex flex-col gap-1">
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Guidance scale"
|
text="Guidance scale"
|
||||||
@ -637,10 +642,11 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
{renderP2PImageGuidanceScale()}
|
const renderSampler = () => {
|
||||||
{renderStrength()}
|
return (
|
||||||
|
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
<LabelTitle text="Sampler" />
|
<LabelTitle text="Sampler" />
|
||||||
<Select
|
<Select
|
||||||
@ -664,7 +670,11 @@ const DiffusionOptions = () => {
|
|||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderSeed = () => {
|
||||||
|
return (
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
{/* 每次会从服务器返回更新该值 */}
|
{/* 每次会从服务器返回更新该值 */}
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
@ -692,15 +702,11 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
{renderNegativePrompt()}
|
const renderMaskBlur = () => {
|
||||||
|
return (
|
||||||
<Separator />
|
|
||||||
|
|
||||||
{renderConterNetSetting()}
|
|
||||||
{renderFreeu()}
|
|
||||||
{renderLCMLora()}
|
|
||||||
|
|
||||||
<div className="flex flex-col gap-1">
|
<div className="flex flex-col gap-1">
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Mask blur"
|
text="Mask blur"
|
||||||
@ -727,7 +733,12 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderMatchHistograms = () => {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Match histograms"
|
text="Match histograms"
|
||||||
@ -742,9 +753,29 @@ const DiffusionOptions = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
|
|
||||||
<Separator />
|
<Separator />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-4 mt-4">
|
||||||
|
{renderCropper()}
|
||||||
|
{renderExtender()}
|
||||||
|
{renderPowerPaintTaskType()}
|
||||||
|
{renderSteps()}
|
||||||
|
{renderGuidanceScale()}
|
||||||
|
{renderP2PImageGuidanceScale()}
|
||||||
|
{renderStrength()}
|
||||||
|
{renderSampler()}
|
||||||
|
{renderSeed()}
|
||||||
|
{renderNegativePrompt()}
|
||||||
|
<Separator />
|
||||||
|
{renderConterNetSetting()}
|
||||||
|
{renderLCMLora()}
|
||||||
|
{renderMaskBlur()}
|
||||||
|
{renderMatchHistograms()}
|
||||||
|
{renderFreeu()}
|
||||||
{renderPaintByExample()}
|
{renderPaintByExample()}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
@ -14,11 +14,11 @@ const Workspace = () => {
|
|||||||
])
|
])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
currentModel()
|
const fetchCurrentModel = async () => {
|
||||||
.then((res) => res.json())
|
const model = await currentModel()
|
||||||
.then((model) => {
|
|
||||||
updateSettings({ model })
|
updateSettings({ model })
|
||||||
})
|
}
|
||||||
|
fetchCurrentModel()
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -1,4 +1,11 @@
|
|||||||
import { Filename, ModelInfo, PowerPaintTask, Rect } from "@/lib/types"
|
import {
|
||||||
|
Filename,
|
||||||
|
GenInfo,
|
||||||
|
ModelInfo,
|
||||||
|
PowerPaintTask,
|
||||||
|
Rect,
|
||||||
|
ServerConfig,
|
||||||
|
} from "@/lib/types"
|
||||||
import { Settings } from "@/lib/states"
|
import { Settings } from "@/lib/states"
|
||||||
import { convertToBase64, srcToFile } from "@/lib/utils"
|
import { convertToBase64, srcToFile } from "@/lib/utils"
|
||||||
import axios from "axios"
|
import axios from "axios"
|
||||||
@ -24,7 +31,7 @@ export default async function inpaint(
|
|||||||
const exampleImageBase64 = paintByExampleImage
|
const exampleImageBase64 = paintByExampleImage
|
||||||
? await convertToBase64(paintByExampleImage)
|
? await convertToBase64(paintByExampleImage)
|
||||||
: null
|
: null
|
||||||
try {
|
|
||||||
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
|
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
@ -76,72 +83,61 @@ export default async function inpaint(
|
|||||||
: settings.powerpaintTask,
|
: settings.powerpaintTask,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
|
if (res.ok) {
|
||||||
const blob = await res.blob()
|
const blob = await res.blob()
|
||||||
return {
|
return {
|
||||||
blob: URL.createObjectURL(blob),
|
blob: URL.createObjectURL(blob),
|
||||||
seed: res.headers.get("X-Seed"),
|
seed: res.headers.get("X-Seed"),
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
|
||||||
throw new Error(`Something went wrong: ${JSON.stringify(error.message)}`)
|
|
||||||
}
|
}
|
||||||
|
const errors = await res.json()
|
||||||
|
throw new Error(`Something went wrong: ${errors.errors}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getServerConfig() {
|
export async function getServerConfig(): Promise<ServerConfig> {
|
||||||
return fetch(`${API_ENDPOINT}/server-config`, {
|
const res = await api.get(`/server-config`)
|
||||||
method: "GET",
|
return res.data
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function switchModel(name: string) {
|
export async function switchModel(name: string): Promise<ModelInfo> {
|
||||||
return axios.post(`${API_ENDPOINT}/model`, { name })
|
const res = await api.post(`/model`, { name })
|
||||||
|
return res.data
|
||||||
}
|
}
|
||||||
|
|
||||||
export function currentModel() {
|
export async function currentModel(): Promise<ModelInfo> {
|
||||||
return fetch(`${API_ENDPOINT}/model`, {
|
const res = await api.get("/model")
|
||||||
method: "GET",
|
return res.data
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function fetchModelInfos(): Promise<ModelInfo[]> {
|
export function fetchModelInfos(): Promise<ModelInfo[]> {
|
||||||
return api.get("/models").then((response) => response.data)
|
return api.get("/models").then((response) => response.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
export function modelDownloaded(name: string) {
|
|
||||||
return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, {
|
|
||||||
method: "GET",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function runPlugin(
|
export async function runPlugin(
|
||||||
name: string,
|
name: string,
|
||||||
imageFile: File,
|
imageFile: File,
|
||||||
upscale?: number,
|
upscale?: number,
|
||||||
clicks?: number[][]
|
clicks?: number[][]
|
||||||
) {
|
) {
|
||||||
const fd = new FormData()
|
const imageBase64 = await convertToBase64(imageFile)
|
||||||
fd.append("name", name)
|
|
||||||
fd.append("image", imageFile)
|
|
||||||
if (upscale) {
|
|
||||||
fd.append("upscale", upscale.toString())
|
|
||||||
}
|
|
||||||
if (clicks) {
|
|
||||||
fd.append("clicks", JSON.stringify(clicks))
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
|
const res = await fetch(`${API_ENDPOINT}/run_plugin`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: fd,
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
name,
|
||||||
|
image: imageBase64,
|
||||||
|
upscale,
|
||||||
|
clicks,
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
if (res.ok) {
|
if (res.ok) {
|
||||||
const blob = await res.blob()
|
const blob = await res.blob()
|
||||||
return { blob: URL.createObjectURL(blob) }
|
return { blob: URL.createObjectURL(blob) }
|
||||||
}
|
}
|
||||||
const errMsg = await res.text()
|
const errMsg = await res.json()
|
||||||
throw new Error(errMsg)
|
throw new Error(errMsg)
|
||||||
} catch (error) {
|
|
||||||
throw new Error(`Something went wrong: ${error}`)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getMediaFile(tab: string, filename: string) {
|
export async function getMediaFile(tab: string, filename: string) {
|
||||||
@ -160,12 +156,12 @@ export async function getMediaFile(tab: string, filename: string) {
|
|||||||
})
|
})
|
||||||
return file
|
return file
|
||||||
}
|
}
|
||||||
const errMsg = await res.text()
|
const errMsg = await res.json()
|
||||||
throw new Error(errMsg)
|
throw new Error(errMsg.errors)
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getMedias(tab: string): Promise<Filename[]> {
|
export async function getMedias(tab: string): Promise<Filename[]> {
|
||||||
const res = await axios.get(`${API_ENDPOINT}/medias`, { params: { tab } })
|
const res = await api.get(`medias`, { params: { tab } })
|
||||||
return res.data
|
return res.data
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,3 +187,10 @@ export async function downloadToOutput(
|
|||||||
throw new Error(`Something went wrong: ${error}`)
|
throw new Error(`Something went wrong: ${error}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function getGenInfo(file: File): Promise<GenInfo> {
|
||||||
|
const fd = new FormData()
|
||||||
|
fd.append("file", file)
|
||||||
|
const res = await api.post(`/gen-info`, fd)
|
||||||
|
return res.data
|
||||||
|
}
|
||||||
|
@ -15,6 +15,7 @@ import {
|
|||||||
Point,
|
Point,
|
||||||
PowerPaintTask,
|
PowerPaintTask,
|
||||||
SDSampler,
|
SDSampler,
|
||||||
|
ServerConfig,
|
||||||
Size,
|
Size,
|
||||||
SortBy,
|
SortBy,
|
||||||
SortOrder,
|
SortOrder,
|
||||||
@ -33,7 +34,7 @@ import {
|
|||||||
loadImage,
|
loadImage,
|
||||||
srcToFile,
|
srcToFile,
|
||||||
} from "./utils"
|
} from "./utils"
|
||||||
import inpaint, { runPlugin } from "./api"
|
import inpaint, { getGenInfo, runPlugin } from "./api"
|
||||||
import { toast } from "@/components/ui/use-toast"
|
import { toast } from "@/components/ui/use-toast"
|
||||||
|
|
||||||
type FileManagerState = {
|
type FileManagerState = {
|
||||||
@ -57,6 +58,7 @@ export type Settings = {
|
|||||||
enableDownloadMask: boolean
|
enableDownloadMask: boolean
|
||||||
enableManualInpainting: boolean
|
enableManualInpainting: boolean
|
||||||
enableUploadMask: boolean
|
enableUploadMask: boolean
|
||||||
|
enableAutoExtractPrompt: boolean
|
||||||
showCropper: boolean
|
showCropper: boolean
|
||||||
showExtender: boolean
|
showExtender: boolean
|
||||||
extenderDirection: ExtenderDirection
|
extenderDirection: ExtenderDirection
|
||||||
@ -103,16 +105,6 @@ export type Settings = {
|
|||||||
powerpaintTask: PowerPaintTask
|
powerpaintTask: PowerPaintTask
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig = {
|
|
||||||
plugins: string[]
|
|
||||||
enableFileManager: boolean
|
|
||||||
enableAutoSaving: boolean
|
|
||||||
enableControlnet: boolean
|
|
||||||
controlnetMethod: string
|
|
||||||
disableModelSwitch: boolean
|
|
||||||
isDesktop: boolean
|
|
||||||
}
|
|
||||||
|
|
||||||
type InteractiveSegState = {
|
type InteractiveSegState = {
|
||||||
isInteractiveSeg: boolean
|
isInteractiveSeg: boolean
|
||||||
interactiveSegMask: HTMLImageElement | null
|
interactiveSegMask: HTMLImageElement | null
|
||||||
@ -162,7 +154,7 @@ type AppState = {
|
|||||||
|
|
||||||
type AppAction = {
|
type AppAction = {
|
||||||
updateAppState: (newState: Partial<AppState>) => void
|
updateAppState: (newState: Partial<AppState>) => void
|
||||||
setFile: (file: File) => void
|
setFile: (file: File) => Promise<void>
|
||||||
setCustomFile: (file: File) => void
|
setCustomFile: (file: File) => void
|
||||||
setIsInpainting: (newValue: boolean) => void
|
setIsInpainting: (newValue: boolean) => void
|
||||||
setIsPluginRunning: (newValue: boolean) => void
|
setIsPluginRunning: (newValue: boolean) => void
|
||||||
@ -304,6 +296,7 @@ const defaultValues: AppState = {
|
|||||||
enableDownloadMask: false,
|
enableDownloadMask: false,
|
||||||
enableManualInpainting: false,
|
enableManualInpainting: false,
|
||||||
enableUploadMask: false,
|
enableUploadMask: false,
|
||||||
|
enableAutoExtractPrompt: true,
|
||||||
ldmSteps: 30,
|
ldmSteps: 30,
|
||||||
ldmSampler: LDMSampler.ddim,
|
ldmSampler: LDMSampler.ddim,
|
||||||
zitsWireframe: true,
|
zitsWireframe: true,
|
||||||
@ -540,9 +533,6 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
const start = new Date()
|
const start = new Date()
|
||||||
const targetFile = await get().getCurrentTargetFile()
|
const targetFile = await get().getCurrentTargetFile()
|
||||||
const res = await runPlugin(pluginName, targetFile, params.upscale)
|
const res = await runPlugin(pluginName, targetFile, params.upscale)
|
||||||
if (!res) {
|
|
||||||
throw new Error("Something went wrong on server side.")
|
|
||||||
}
|
|
||||||
const { blob } = res
|
const { blob } = res
|
||||||
const newRender = new Image()
|
const newRender = new Image()
|
||||||
await loadImage(newRender, blob)
|
await loadImage(newRender, blob)
|
||||||
@ -818,7 +808,27 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
state.isPluginRunning = newValue
|
state.isPluginRunning = newValue
|
||||||
}),
|
}),
|
||||||
|
|
||||||
setFile: (file: File) => {
|
setFile: async (file: File) => {
|
||||||
|
if (get().settings.enableAutoExtractPrompt) {
|
||||||
|
try {
|
||||||
|
const res = await getGenInfo(file)
|
||||||
|
if (res.prompt) {
|
||||||
|
set((state) => {
|
||||||
|
state.settings.prompt = res.prompt
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if (res.negative_prompt) {
|
||||||
|
set((state) => {
|
||||||
|
state.settings.negativePrompt = res.negative_prompt
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} catch (e: any) {
|
||||||
|
toast({
|
||||||
|
variant: "destructive",
|
||||||
|
description: e.message ? e.message : e.toString(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.file = file
|
state.file = file
|
||||||
state.interactiveSegState = castDraft(
|
state.interactiveSegState = castDraft(
|
||||||
|
@ -6,6 +6,21 @@ export interface Filename {
|
|||||||
mtime: number
|
mtime: number
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ServerConfig {
|
||||||
|
plugins: string[]
|
||||||
|
enableFileManager: boolean
|
||||||
|
enableAutoSaving: boolean
|
||||||
|
enableControlnet: boolean
|
||||||
|
controlnetMethod: string
|
||||||
|
disableModelSwitch: boolean
|
||||||
|
isDesktop: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GenInfo {
|
||||||
|
prompt: string
|
||||||
|
negative_prompt: string
|
||||||
|
}
|
||||||
|
|
||||||
export interface ModelInfo {
|
export interface ModelInfo {
|
||||||
name: string
|
name: string
|
||||||
path: string
|
path: string
|
||||||
|
Loading…
Reference in New Issue
Block a user