diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py index 0a6697e..775bc46 100644 --- a/lama_cleaner/__init__.py +++ b/lama_cleaner/__init__.py @@ -6,8 +6,6 @@ import warnings warnings.simplefilter("ignore", UserWarning) -from lama_cleaner.parse_args import parse_args - def entry_point(): # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index 486fd11..b506349 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -23,14 +23,6 @@ AVAILABLE_MODELS = [ "fcf", "manga", "cv2", - "sd1.5", - "anything4", - "realisticVision1.4", - "sd2", - "sdxl", - "paint_by_example", - "instruct_pix2pix", - "kandinsky2.2", ] DIFFUSERS_MODEL_FP16_REVERSION = [ "runwayml/stable-diffusion-inpainting", diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py index aef26c4..698f187 100644 --- a/lama_cleaner/download.py +++ b/lama_cleaner/download.py @@ -41,27 +41,6 @@ def folder_name_to_show_name(name: str) -> str: return name.replace("models--", "").replace("--", "/") -def scan_diffusers_models( - cache_dir, class_name: List[str], model_type: ModelType -) -> List[ModelInfo]: - cache_dir = Path(cache_dir) - res = [] - for it in cache_dir.glob("**/*/model_index.json"): - with open(it, "r", encoding="utf-8") as f: - data = json.load(f) - if data["_class_name"] in class_name: - name = folder_name_to_show_name(it.parent.parent.parent.name) - if name not in res: - res.append( - ModelInfo( - name=name, - path=name, - model_type=model_type, - ) - ) - return res - - def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: cache_dir = Path(cache_dir) res = [] @@ -111,7 +90,6 @@ def scan_models() -> List[ModelInfo]: available_models = [] available_models.extend(scan_inpaint_models()) available_models.extend(scan_single_file_diffusion_models(DEFAULT_MODEL_DIR)) - cache_dir = Path(DIFFUSERS_CACHE) diffusers_model_names = [] for it in cache_dir.glob("**/*/model_index.json"): diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index d12254e..783a01b 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -65,5 +65,5 @@ class Kandinsky(DiffusionInpaintModel): class Kandinsky22(Kandinsky): - name = "kandinsky2.2" + name = "kandinsky-community/kandinsky-2-2-decoder-inpaint" model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint" diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index 3c2783b..80b9745 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -38,12 +38,6 @@ class PaintByExample(DiffusionInpaintModel): else: self.model = self.model.to(device) - @staticmethod - def download(): - from diffusers import DiffusionPipeline - - DiffusionPipeline.from_pretrained("Fantasy-Studio/Paint-by-Example") - def forward(self, image, mask, config: Config): """Input image and output image have same size image: [H, W, C] RGB diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 9d5ac5f..01f85a3 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -22,20 +22,11 @@ class ModelManager: self.sd_controlnet_method = "" self.model = self.init_model(name, device, **kwargs) - def _map_old_name(self, name: str) -> str: - for old_name, model_cls in models.items(): - if name == old_name and hasattr(model_cls, "model_id_or_path"): - name = model_cls.model_id_or_path - break - return name - @property def current_model(self) -> Dict: - name = self._map_old_name(self.name) return self.available_models[name].model_dump() def init_model(self, name: str, device, **kwargs): - name = self._map_old_name(name) logger.info(f"Loading model: {name}") if name not in self.available_models: raise NotImplementedError(f"Unsupported model: {name}") diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py deleted file mode 100644 index 4c5f412..0000000 --- a/lama_cleaner/parse_args.py +++ /dev/null @@ -1,257 +0,0 @@ -import os -import imghdr -import argparse -from pathlib import Path - -from loguru import logger - -from lama_cleaner.const import * -from lama_cleaner.download import cli_download_model, scan_models -from lama_cleaner.runtime import dump_environment_info - -DOWNLOAD_SUBCOMMAND = "download" - - -def download_parse_args(parser): - subparsers = parser.add_subparsers(dest="subcommand") - subparser = subparsers.add_parser(DOWNLOAD_SUBCOMMAND, help="Download models") - subparser.add_argument( - "--model", help="Erase model name(lama/mat...) or model id on huggingface" - ) - subparser.add_argument( - "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP - ) - - -def parse_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - download_parse_args(parser) - - parser.add_argument("--host", default="127.0.0.1") - parser.add_argument("--port", default=8080, type=int) - - parser.add_argument( - "--config-installer", - action="store_true", - help="Open config web page, mainly for windows installer", - ) - parser.add_argument( - "--load-installer-config", - action="store_true", - help="Load all cmd args from installer config file", - ) - parser.add_argument( - "--installer-config", default=None, help="Config file for windows installer" - ) - - parser.add_argument( - "--model", - default=DEFAULT_MODEL, - help=f"Available models: [{', '.join(AVAILABLE_MODELS)}], or model id on huggingface", - ) - parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP) - parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP) - parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP) - parser.add_argument( - "--sd-cpu-textencoder", action="store_true", help=CPU_TEXTENCODER_HELP - ) - parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP) - parser.add_argument( - "--sd-controlnet-method", - default=DEFAULT_SD_CONTROLNET_METHOD, - choices=SD_CONTROLNET_CHOICES, - ) - parser.add_argument( - "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP - ) - parser.add_argument( - "--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES - ) - parser.add_argument("--gui", action="store_true", help=GUI_HELP) - parser.add_argument( - "--gui-size", - default=[1600, 1000], - nargs=2, - type=int, - help="Set window size for GUI", - ) - parser.add_argument("--input", type=str, default=None, help=INPUT_HELP) - parser.add_argument("--output-dir", type=str, default=None, help=OUTPUT_DIR_HELP) - parser.add_argument( - "--model-dir", type=str, default=DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP - ) - parser.add_argument( - "--disable-model-switch", - action="store_true", - help="Disable model switch in frontend", - ) - parser.add_argument( - "--quality", - default=95, - type=int, - help=QUALITY_HELP, - ) - - # Plugins - parser.add_argument( - "--enable-interactive-seg", - action="store_true", - help=INTERACTIVE_SEG_HELP, - ) - parser.add_argument( - "--interactive-seg-model", - default="vit_l", - choices=AVAILABLE_INTERACTIVE_SEG_MODELS, - help=INTERACTIVE_SEG_MODEL_HELP, - ) - parser.add_argument( - "--interactive-seg-device", - default="cpu", - choices=AVAILABLE_INTERACTIVE_SEG_DEVICES, - ) - parser.add_argument( - "--enable-remove-bg", - action="store_true", - help=REMOVE_BG_HELP, - ) - parser.add_argument( - "--enable-anime-seg", - action="store_true", - help=ANIMESEG_HELP, - ) - parser.add_argument( - "--enable-realesrgan", - action="store_true", - help=REALESRGAN_HELP, - ) - parser.add_argument( - "--realesrgan-device", - default="cpu", - type=str, - choices=REALESRGAN_AVAILABLE_DEVICES, - ) - parser.add_argument( - "--realesrgan-model", - default=RealESRGANModelName.realesr_general_x4v3.value, - type=str, - choices=RealESRGANModelNameList, - ) - parser.add_argument( - "--realesrgan-no-half", - action="store_true", - help="Disable half precision for RealESRGAN", - ) - parser.add_argument("--enable-gfpgan", action="store_true", help=GFPGAN_HELP) - parser.add_argument( - "--gfpgan-device", default="cpu", type=str, choices=GFPGAN_AVAILABLE_DEVICES - ) - parser.add_argument( - "--enable-restoreformer", action="store_true", help=RESTOREFORMER_HELP - ) - parser.add_argument( - "--restoreformer-device", - default="cpu", - type=str, - choices=RESTOREFORMER_AVAILABLE_DEVICES, - ) - parser.add_argument( - "--install-plugins-package", - action="store_true", - ) - ######### - - args = parser.parse_args() - # collect system info to help debug - dump_environment_info() - if args.subcommand == DOWNLOAD_SUBCOMMAND: - cli_download_model(args.model, args.model_dir) - return - - if args.install_plugins_package: - from lama_cleaner.installer import install_plugins_package - - install_plugins_package() - exit() - - if args.config_installer: - if args.installer_config is None: - parser.error( - "args.config_installer==True, must set args.installer_config to store config file" - ) - from lama_cleaner.web_config import main - - logger.info("Launching installer web config page") - main(args.installer_config) - exit() - - if args.load_installer_config: - if args.installer_config and not os.path.exists(args.installer_config): - parser.error(f"args.installer_config={args.installer_config} not exists") - - logger.info(f"Loading installer config from {args.installer_config}") - _args = load_config(args.installer_config) - for k, v in vars(_args).items(): - if k in vars(args): - setattr(args, k, v) - - if args.device == "cuda": - import platform - - if platform.system() == "Darwin": - logger.info("MacOS does not support cuda, use cpu instead") - setattr(args, "device", "cpu") - else: - import torch - - if torch.cuda.is_available() is False: - parser.error( - "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation" - ) - - os.environ["U2NET_HOME"] = DEFAULT_MODEL_DIR - if args.model_dir and args.model_dir is not None: - if os.path.isfile(args.model_dir): - parser.error(f"invalid --model-dir: {args.model_dir} is a file") - - if not os.path.exists(args.model_dir): - logger.info(f"Create model cache directory: {args.model_dir}") - Path(args.model_dir).mkdir(exist_ok=True, parents=True) - - os.environ["XDG_CACHE_HOME"] = args.model_dir - os.environ["U2NET_HOME"] = args.model_dir - - if args.sd_run_local or args.local_files_only: - os.environ["TRANSFORMERS_OFFLINE"] = "1" - os.environ["HF_HUB_OFFLINE"] = "1" - - if args.model not in AVAILABLE_MODELS: - scanned_models = scan_models() - if args.model not in [it.name for it in scanned_models]: - parser.error( - f"invalid --model: {args.model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}" - ) - - if args.input and args.input is not None: - if not os.path.exists(args.input): - parser.error(f"invalid --input: {args.input} not exists") - if os.path.isfile(args.input): - if imghdr.what(args.input) is None: - parser.error(f"invalid --input: {args.input} is not a valid image file") - else: - if args.output_dir is None: - parser.error( - f"invalid --input: {args.input} is a directory, --output-dir is required" - ) - - if args.output_dir is not None: - output_dir = Path(args.output_dir) - if not output_dir.exists(): - logger.info(f"Creating output directory: {output_dir}") - output_dir.mkdir(parents=True) - else: - if not output_dir.is_dir(): - parser.error(f"invalid --output-dir: {output_dir} is not a directory") - - return args diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 4fb42f5..052cf06 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -585,7 +585,7 @@ def start( port: int = Option(8080), model: str = Option( DEFAULT_MODEL, - help=f"Available models: [{', '.join(AVAILABLE_MODELS)}]. " + help=f"Available erase models: [{', '.join(AVAILABLE_MODELS)}]. " f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface", ), model_dir: Path = Option( @@ -644,13 +644,12 @@ def start( os.environ["TRANSFORMERS_OFFLINE"] = "1" os.environ["HF_HUB_OFFLINE"] = "1" - if model not in AVAILABLE_MODELS: - scanned_models = scan_models() - if model not in [it.name for it in scanned_models]: - logger.error( - f"invalid --model: {model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}" - ) - exit() + scanned_models = scan_models() + if model not in [it.name for it in scanned_models]: + logger.error( + f"invalid model: {model} not exists. Available models: {[it.name for it in scanned_models]}" + ) + exit() global_config.image_quality = quality global_config.disable_model_switch = disable_model_switch diff --git a/requirements.txt b/requirements.txt index 016a179..7f6e668 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -torch>=1.9.0 +torch>=2.0.0 +typer opencv-python flask==2.2.3 flask-socketio diff --git a/web_app/src/App.tsx b/web_app/src/App.tsx index 76a5a10..c591a5c 100644 --- a/web_app/src/App.tsx +++ b/web_app/src/App.tsx @@ -3,7 +3,7 @@ import { nanoid } from "nanoid" import useInputImage from "@/hooks/useInputImage" import { keepGUIAlive } from "@/lib/utils" -import { getServerConfig, isDesktop } from "@/lib/api" +import { getServerConfig } from "@/lib/api" import Header from "@/components/Header" import Workspace from "@/components/Workspace" import FileSelect from "@/components/FileSelect" @@ -40,21 +40,14 @@ function Home() { updateAppState({ windowSize }) }, [windowSize]) - // Keeping GUI Window Open - useEffect(() => { - const fetchData = async () => { - const isRunDesktop = await isDesktop().then((res) => res.text()) - if (isRunDesktop === "True") { - keepGUIAlive() - } - } - fetchData() - }, []) - useEffect(() => { const fetchServerConfig = async () => { const serverConfig = await getServerConfig().then((res) => res.json()) setServerConfig(serverConfig) + if (serverConfig.isDesktop) { + // Keeping GUI Window Open + keepGUIAlive() + } } fetchServerConfig() }, []) diff --git a/web_app/src/components/Editor.tsx b/web_app/src/components/Editor.tsx index 9ef276e..50f256c 100644 --- a/web_app/src/components/Editor.tsx +++ b/web_app/src/components/Editor.tsx @@ -382,6 +382,9 @@ export default function Editor(props: EditorProps) { } const onPointerUp = (ev: SyntheticEvent) => { + if (!hadDrawSomething()) { + return + } if (isMidClick(ev)) { setIsPanning(false) } diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index 14d69b9..1a18b64 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -1,7 +1,7 @@ import { IconButton } from "@/components/ui/button" import { useToggle } from "@uidotdev/usehooks" import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "./ui/dialog" -import { Info, Settings } from "lucide-react" +import { HelpCircle, Settings } from "lucide-react" import { zodResolver } from "@hookform/resolvers/zod" import { useForm } from "react-hook-form" import * as z from "zod" @@ -19,7 +19,7 @@ import { import { Input } from "@/components/ui/input" import { Switch } from "./ui/switch" import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs" -import { useState } from "react" +import { useEffect, useState } from "react" import { cn } from "@/lib/utils" import { useQuery } from "@tanstack/react-query" import { fetchModelInfos, switchModel } from "@/lib/api" @@ -30,7 +30,6 @@ import { useToast } from "./ui/use-toast" import { AlertDialog, AlertDialogContent, - AlertDialogDescription, AlertDialogHeader, } from "./ui/alert-dialog" import { @@ -85,6 +84,9 @@ export function SettingsDialog() { ]) const { toast } = useToast() const [model, setModel] = useState(settings.model) + useEffect(() => { + setModel(settings.model) + }, [settings.model]) const { data: modelInfos, status } = useQuery({ queryKey: ["modelInfos"], @@ -163,7 +165,6 @@ export function SettingsDialog() { } function onModelSelect(info: ModelInfo) { - console.log(info) setModel(info) } @@ -211,35 +212,35 @@ export function SettingsDialog() { } return ( -
+
-
Current Model
+
Current Model
{model.name}
-
-
Available models
- - - +
+
Available models
+ {/* + + */}
- Inpaint + Erase - Diffusion + Stable Diffusion - Diffusion inpaint + Stable Diffusion Inpaint - Diffusion other + Other Diffusion - + {renderModelList([MODEL_TYPE_INPAINT])} @@ -267,7 +268,7 @@ export function SettingsDialog() { function renderGeneralSettings() { return ( -
+
Enable manual inpainting - Click a button to trigger inpainting after draw mask. + For erase model, click a button to trigger inpainting after + draw mask.
@@ -468,7 +470,7 @@ export function SettingsDialog() {
-
+
{tab === TAB_MODEL ? renderModelSettings() : <>} {tab === TAB_GENERAL ? renderGeneralSettings() : <>} diff --git a/web_app/src/components/SidePanel/CV2Options.tsx b/web_app/src/components/SidePanel/CV2Options.tsx new file mode 100644 index 0000000..e267941 --- /dev/null +++ b/web_app/src/components/SidePanel/CV2Options.tsx @@ -0,0 +1,77 @@ +import { useStore } from "@/lib/states" +import { LabelTitle, RowContainer } from "./LabelTitle" +import { NumberInput } from "../ui/input" +import { Slider } from "../ui/slider" +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from "../ui/select" +import { CV2Flag } from "@/lib/types" + +const CV2Options = () => { + const [settings, updateSettings] = useStore((state) => [ + state.settings, + state.updateSettings, + ]) + + return ( +
+ + + + + + + updateSettings({ cv2Radius: vals[0] })} + /> + { + updateSettings({ cv2Radius: val }) + }} + /> + +
+ ) +} + +export default CV2Options diff --git a/web_app/src/components/SidePanel.tsx b/web_app/src/components/SidePanel/DiffusionOptions.tsx similarity index 55% rename from web_app/src/components/SidePanel.tsx rename to web_app/src/components/SidePanel/DiffusionOptions.tsx index 535f1e7..eac00de 100644 --- a/web_app/src/components/SidePanel.tsx +++ b/web_app/src/components/SidePanel/DiffusionOptions.tsx @@ -1,9 +1,7 @@ -import { FormEvent, useState } from "react" -import { useToggle } from "react-use" +import { FormEvent } from "react" import { useStore } from "@/lib/states" -import { Switch } from "./ui/switch" -import { Label } from "./ui/label" -import { NumberInput } from "./ui/input" +import { Switch } from "../ui/switch" +import { NumberInput } from "../ui/input" import { Select, SelectContent, @@ -11,56 +9,28 @@ import { SelectItem, SelectTrigger, SelectValue, -} from "./ui/select" -import { Textarea } from "./ui/textarea" +} from "../ui/select" +import { Textarea } from "../ui/textarea" import { SDSampler } from "@/lib/types" -import { Separator } from "./ui/separator" -import { ScrollArea } from "./ui/scroll-area" -import { Sheet, SheetContent, SheetHeader, SheetTrigger } from "./ui/sheet" -import { - ArrowDownFromLine, - ArrowLeftFromLine, - ArrowRightFromLine, - ArrowUpFromLine, - ChevronLeft, - ChevronRight, - HelpCircle, - LucideIcon, - Maximize, - Move, - MoveHorizontal, - MoveVertical, - Upload, -} from "lucide-react" -import { Button, ImageUploadButton } from "./ui/button" -import useHotKey from "@/hooks/useHotkey" -import { Slider } from "./ui/slider" +import { Separator } from "../ui/separator" +import { Move, MoveHorizontal, MoveVertical, Upload } from "lucide-react" +import { Button, ImageUploadButton } from "../ui/button" +import { Slider } from "../ui/slider" import { useImage } from "@/hooks/useImage" import { EXTENDER_ALL, - EXTENDER_BUILTIN_ALL, - EXTENDER_BUILTIN_X_LEFT, - EXTENDER_BUILTIN_X_RIGHT, - EXTENDER_BUILTIN_Y_BOTTOM, - EXTENDER_BUILTIN_Y_TOP, EXTENDER_X, EXTENDER_Y, INSTRUCT_PIX2PIX, PAINT_BY_EXAMPLE, } from "@/lib/const" -import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs" -import { Tooltip, TooltipContent, TooltipTrigger } from "./ui/tooltip" - -const RowContainer = ({ children }: { children: React.ReactNode }) => ( -
{children}
-) +import { Tabs, TabsContent, TabsList, TabsTrigger } from "../ui/tabs" +import { RowContainer, LabelTitle } from "./LabelTitle" const ExtenderButton = ({ - IconCls, text, onClick, }: { - IconCls: LucideIcon text: string onClick: () => void }) => { @@ -73,92 +43,32 @@ const ExtenderButton = ({ disabled={!showExtender} onClick={onClick} > -
- - {text} -
+
{text}
) } -const LabelTitle = ({ - text, - toolTip, - url, - htmlFor, - disabled = false, -}: { - text: string - toolTip?: string - url?: string - htmlFor?: string - disabled?: boolean -}) => { - return ( - - - - - {toolTip ? ( - -

{toolTip}

- {url ? ( - - ) : ( - <> - )} -
- ) : ( - <> - )} -
- ) -} - -const SidePanel = () => { +const DiffusionOptions = () => { const [ settings, - windowSize, paintByExampleFile, isProcessing, updateSettings, - showSidePanel, runInpainting, updateAppState, updateExtenderByBuiltIn, updateExtenderDirection, ] = useStore((state) => [ state.settings, - state.windowSize, state.paintByExampleFile, state.getIsProcessing(), state.updateSettings, - state.showSidePanel(), state.runInpainting, state.updateAppState, state.updateExtenderByBuiltIn, state.updateExtenderDirection, ]) const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile) - const [open, toggleOpen] = useToggle(true) - - useHotKey("c", () => { - toggleOpen() - }) - - if (!showSidePanel) { - return null - } const onKeyUp = (e: React.KeyboardEvent) => { // negativePrompt 回车触发 inpainting @@ -582,32 +492,20 @@ const SidePanel = () => { className="flex gap-2 justify-center mt-0" > updateExtenderByBuiltIn(EXTENDER_X, 1.25)} + /> + - updateExtenderByBuiltIn(EXTENDER_BUILTIN_X_LEFT, 1.5) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 1.5)} + /> + updateExtenderByBuiltIn(EXTENDER_X, 1.75)} /> - updateExtenderByBuiltIn(EXTENDER_BUILTIN_X_LEFT, 2.0) - } - /> - - updateExtenderByBuiltIn(EXTENDER_BUILTIN_X_RIGHT, 1.5) - } - /> - - updateExtenderByBuiltIn(EXTENDER_BUILTIN_X_RIGHT, 2.0) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_X, 2.0)} /> { className="flex gap-2 justify-center mt-0" > updateExtenderByBuiltIn(EXTENDER_Y, 1.25)} + /> + - updateExtenderByBuiltIn(EXTENDER_BUILTIN_Y_TOP, 1.5) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 1.5)} + /> + updateExtenderByBuiltIn(EXTENDER_Y, 1.75)} /> - updateExtenderByBuiltIn(EXTENDER_BUILTIN_Y_TOP, 2.0) - } - /> - - updateExtenderByBuiltIn(EXTENDER_BUILTIN_Y_BOTTOM, 1.5) - } - /> - - updateExtenderByBuiltIn(EXTENDER_BUILTIN_Y_BOTTOM, 2.0) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_Y, 2.0)} /> { className="flex gap-2 justify-center mt-0" > - updateExtenderByBuiltIn(EXTENDER_BUILTIN_ALL, 1.25) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.25)} /> - updateExtenderByBuiltIn(EXTENDER_BUILTIN_ALL, 1.5) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.5)} /> - updateExtenderByBuiltIn(EXTENDER_BUILTIN_ALL, 1.75) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 1.75)} /> - updateExtenderByBuiltIn(EXTENDER_BUILTIN_ALL, 2.0) - } + onClick={() => updateExtenderByBuiltIn(EXTENDER_ALL, 2.0)} /> @@ -684,247 +558,194 @@ const SidePanel = () => { } return ( - - -
- - - + {renderPaintByExample()} +
) } -export default SidePanel +export default DiffusionOptions diff --git a/web_app/src/components/SidePanel/LDMOptions.tsx b/web_app/src/components/SidePanel/LDMOptions.tsx new file mode 100644 index 0000000..281372c --- /dev/null +++ b/web_app/src/components/SidePanel/LDMOptions.tsx @@ -0,0 +1,77 @@ +import { useStore } from "@/lib/states" +import { LabelTitle, RowContainer } from "./LabelTitle" +import { NumberInput } from "../ui/input" +import { Slider } from "../ui/slider" +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from "../ui/select" +import { LDMSampler } from "@/lib/types" + +const LDMOptions = () => { + const [settings, updateSettings] = useStore((state) => [ + state.settings, + state.updateSettings, + ]) + + return ( +
+
+ + + updateSettings({ ldmSteps: vals[0] })} + /> + { + updateSettings({ ldmSteps: val }) + }} + /> + +
+ + + + +
+ ) +} + +export default LDMOptions diff --git a/web_app/src/components/SidePanel/LabelTitle.tsx b/web_app/src/components/SidePanel/LabelTitle.tsx new file mode 100644 index 0000000..09c8ebc --- /dev/null +++ b/web_app/src/components/SidePanel/LabelTitle.tsx @@ -0,0 +1,53 @@ +import { Button } from "../ui/button" +import { Label } from "../ui/label" +import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip" + +const RowContainer = ({ children }: { children: React.ReactNode }) => ( +
{children}
+) + +const LabelTitle = ({ + text, + toolTip = "", + url, + htmlFor, + disabled = false, +}: { + text: string + toolTip?: string + url?: string + htmlFor?: string + disabled?: boolean +}) => { + return ( + + + + + {toolTip || url ? ( + +

{toolTip}

+ {url ? ( + + ) : ( + <> + )} +
+ ) : ( + <> + )} +
+ ) +} + +export { LabelTitle, RowContainer } diff --git a/web_app/src/components/SidePanel/index.tsx b/web_app/src/components/SidePanel/index.tsx new file mode 100644 index 0000000..e8e1d9a --- /dev/null +++ b/web_app/src/components/SidePanel/index.tsx @@ -0,0 +1,98 @@ +import { useToggle } from "react-use" +import { useStore } from "@/lib/states" +import { Separator } from "../ui/separator" +import { ScrollArea } from "../ui/scroll-area" +import { Sheet, SheetContent, SheetHeader, SheetTrigger } from "../ui/sheet" +import { ChevronLeft, ChevronRight } from "lucide-react" +import { Button } from "../ui/button" +import useHotKey from "@/hooks/useHotkey" +import { RowContainer } from "./LabelTitle" +import { CV2, LDM, MODEL_TYPE_INPAINT } from "@/lib/const" +import LDMOptions from "./LDMOptions" +import DiffusionOptions from "./DiffusionOptions" +import CV2Options from "./CV2Options" + +const SidePanel = () => { + const [settings, windowSize] = useStore((state) => [ + state.settings, + state.windowSize, + ]) + + const [open, toggleOpen] = useToggle(true) + + useHotKey("c", () => { + toggleOpen() + }) + + if ( + settings.model.name !== LDM && + settings.model.name !== CV2 && + settings.model.model_type === MODEL_TYPE_INPAINT + ) { + return null + } + + const renderSidePanelOptions = () => { + if (settings.model.name === LDM) { + return + } + if (settings.model.name === CV2) { + return + } + return + } + + return ( + + + + + event.preventDefault()} + onPointerDownOutside={(event) => event.preventDefault()} + > + + +
+ { + settings.model.name.split("/")[ + settings.model.name.split("/").length - 1 + ] + } +
+ +
+ +
+ + {renderSidePanelOptions()} + +
+
+ ) +} + +export default SidePanel diff --git a/web_app/src/components/Workspace.tsx b/web_app/src/components/Workspace.tsx index 5ed63e9..909bb10 100644 --- a/web_app/src/components/Workspace.tsx +++ b/web_app/src/components/Workspace.tsx @@ -1,23 +1,11 @@ import { useEffect } from "react" import Editor from "./Editor" -import { - AIModel, - isPaintByExampleState, - isPix2PixState, - isSDState, -} from "@/lib/store" import { currentModel } from "@/lib/api" import { useStore } from "@/lib/states" import ImageSize from "./ImageSize" import Plugins from "./Plugins" import { InteractiveSeg } from "./InteractiveSeg" import SidePanel from "./SidePanel" -// import SidePanel from "./SidePanel/SidePanel" -// import PESidePanel from "./SidePanel/PESidePanel" -// import P2PSidePanel from "./SidePanel/P2PSidePanel" -// import Plugins from "./Plugins/Plugins" -// import Flex from "./shared/Layout" -// import ImageSize from "./ImageSize/ImageSize" const Workspace = () => { const [file, updateSettings] = useStore((state) => [ @@ -35,10 +23,6 @@ const Workspace = () => { return ( <> - {/* {isSD ? : <>} - {isPaintByExample ? : <>} - {isPix2Pix ? : <>} - {/* */}
diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index 3b09f41..151cd57 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -125,12 +125,6 @@ export function fetchModelInfos(): Promise { return api.get("/models").then((response) => response.data) } -export function isDesktop() { - return fetch(`${API_ENDPOINT}/is_desktop`, { - method: "GET", - }) -} - export function modelDownloaded(name: string) { return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, { method: "GET", diff --git a/web_app/src/lib/const.ts b/web_app/src/lib/const.ts index 1a782ad..4291381 100644 --- a/web_app/src/lib/const.ts +++ b/web_app/src/lib/const.ts @@ -11,12 +11,9 @@ export const BRUSH_COLOR = "#ffcc00bb" export const EXTENDER_X = "extender_x" export const EXTENDER_Y = "extender_y" export const EXTENDER_ALL = "extender_all" -export const EXTENDER_BUILTIN_X_LEFT = "extender_builtin_x_left" -export const EXTENDER_BUILTIN_X_RIGHT = "extender_builtin_x_right" -export const EXTENDER_BUILTIN_Y_TOP = "extender_builtin_y_top" -export const EXTENDER_BUILTIN_Y_BOTTOM = "extender_builtin_y_bottom" -export const EXTENDER_BUILTIN_ALL = "extender_builtin_all" +export const LDM = "ldm" +export const CV2 = "cv2" export const PAINT_BY_EXAMPLE = "Fantasy-Studio/Paint-by-Example" export const INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix" export const KANDINSKY_2_2 = "kandinsky-community/kandinsky-2-2-decoder-inpaint" diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index 52904c2..559e5cb 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -22,11 +22,6 @@ import { DEFAULT_BRUSH_SIZE, DEFAULT_NEGATIVE_PROMPT, EXTENDER_ALL, - EXTENDER_BUILTIN_ALL, - EXTENDER_BUILTIN_X_LEFT, - EXTENDER_BUILTIN_X_RIGHT, - EXTENDER_BUILTIN_Y_BOTTOM, - EXTENDER_BUILTIN_Y_TOP, EXTENDER_X, EXTENDER_Y, MODEL_TYPE_INPAINT, @@ -112,6 +107,7 @@ type ServerConfig = { enableAutoSaving: boolean enableControlnet: boolean controlnetMethod: string + isDesktop: boolean } type InteractiveSegState = { @@ -129,6 +125,7 @@ type EditorState = { lineGroups: LineGroup[] lastLineGroup: LineGroup curLineGroup: LineGroup + // 只用来显示 extraMasks: HTMLImageElement[] // redo 相关 redoRenders: HTMLImageElement[] @@ -153,7 +150,7 @@ type AppState = { cropperState: CropperState extenderState: CropperState - isCropperExtenderResizing: bool + isCropperExtenderResizing: boolean serverConfig: ServerConfig @@ -194,7 +191,6 @@ type AppAction = { resetInteractiveSegState: () => void handleInteractiveSegAccept: () => void showPromptInput: () => boolean - showSidePanel: () => boolean runInpainting: () => Promise showPrevMask: () => Promise @@ -281,6 +277,7 @@ const defaultValues: AppState = { enableAutoSaving: false, enableControlnet: false, controlnetMethod: "lllyasviel/control_v11p_sd15_canny", + isDesktop: false, }, settings: { model: { @@ -334,6 +331,9 @@ export const useStore = createWithEqualityFn()( ...defaultValues, showPrevMask: async () => { + if (get().settings.showExtender) { + return + } const { lastLineGroup, curLineGroup } = get().editorState const { prevInteractiveSegMask, interactiveSegMask } = get().interactiveSegState @@ -380,7 +380,7 @@ export const useStore = createWithEqualityFn()( } return targetFile }, - // todo: 传入 custom mask,单独逻辑 + runInpainting: async () => { const { isInpainting, @@ -399,6 +399,14 @@ export const useStore = createWithEqualityFn()( if (file === null) { return } + if ( + settings.showExtender && + extenderState.height === imageHeight && + extenderState.width === imageWidth + ) { + return + } + const { lastLineGroup, curLineGroup, lineGroups, renders } = get().editorState @@ -406,42 +414,33 @@ export const useStore = createWithEqualityFn()( get().interactiveSegState const useLastLineGroup = - curLineGroup.length === 0 && interactiveSegMask === null - - const maskImage = useLastLineGroup - ? prevInteractiveSegMask - : interactiveSegMask + curLineGroup.length === 0 && + interactiveSegMask === null && + !settings.showExtender // useLastLineGroup 的影响 // 1. 使用上一次的 mask // 2. 结果替换当前 render + let maskImage = null let maskLineGroup: LineGroup = [] if (useLastLineGroup === true) { - if ( - lastLineGroup.length === 0 && - maskImage === null && - !settings.showExtender - ) { - toast({ - variant: "destructive", - description: "Please draw mask on picture", - }) - return - } maskLineGroup = lastLineGroup + maskImage = prevInteractiveSegMask } else { - if ( - curLineGroup.length === 0 && - maskImage === null && - !settings.showExtender - ) { - toast({ - variant: "destructive", - description: "Please draw mask on picture", - }) - return - } maskLineGroup = curLineGroup + maskImage = interactiveSegMask + } + + if ( + maskLineGroup.length === 0 && + maskImage === null && + !settings.showExtender + ) { + toast({ + variant: "destructive", + description: "Please draw mask on picture", + }) + return } const newLineGroups = [...lineGroups, maskLineGroup] @@ -498,6 +497,7 @@ export const useStore = createWithEqualityFn()( const newRender = new Image() await loadImage(newRender, blob) const newRenders = [...renders, newRender] + get().setImageSize(newRender.width, newRender.height) get().updateEditorState({ renders: newRenders, lineGroups: newLineGroups, @@ -545,7 +545,7 @@ export const useStore = createWithEqualityFn()( const { blob } = res const newRender = new Image() await loadImage(newRender, blob) - get().setImageSize(newRender.height, newRender.width) + get().setImageSize(newRender.width, newRender.height) const newRenders = [...renders, newRender] const newLineGroups = [...lineGroups, []] get().updateEditorState({ @@ -739,11 +739,6 @@ export const useStore = createWithEqualityFn()( ) }, - showSidePanel: (): boolean => { - const model = get().settings.model - return model.model_type !== MODEL_TYPE_INPAINT - }, - setServerConfig: (newValue: ServerConfig) => { set((state) => { state.serverConfig = newValue @@ -910,6 +905,7 @@ export const useStore = createWithEqualityFn()( state.extenderState.width = state.imageWidth state.extenderState.height = state.imageHeight }) + get().updateExtenderByBuiltIn(newValue, 1.5) }, updateExtenderByBuiltIn: (direction: string, scale: number) => { @@ -920,21 +916,15 @@ export const useStore = createWithEqualityFn()( height = imageHeight switch (direction) { - case EXTENDER_BUILTIN_X_LEFT: - x = -Math.ceil(imageWidth * (scale - 1)) + case EXTENDER_X: + x = -Math.ceil((imageWidth * (scale - 1)) / 2) width = Math.ceil(imageWidth * scale) break - case EXTENDER_BUILTIN_X_RIGHT: - width = Math.ceil(imageWidth * scale) - break - case EXTENDER_BUILTIN_Y_TOP: - y = -Math.ceil(imageHeight * (scale - 1)) + case EXTENDER_Y: + y = -Math.ceil((imageHeight * (scale - 1)) / 2) height = Math.ceil(imageHeight * scale) break - case EXTENDER_BUILTIN_Y_BOTTOM: - height = Math.ceil(imageHeight * scale) - break - case EXTENDER_BUILTIN_ALL: + case EXTENDER_ALL: x = -Math.ceil((imageWidth * (scale - 1)) / 2) y = -Math.ceil((imageHeight * (scale - 1)) / 2) width = Math.ceil(imageWidth * scale)