From f7e1e073dcf01d5784b4373923d93874aff22f3f Mon Sep 17 00:00:00 2001 From: Sanster Date: Sun, 17 Apr 2022 23:31:12 +0800 Subject: [PATCH] make model switch work with toast --- lama_cleaner/app/package.json | 1 + lama_cleaner/app/src/adapters/inpainting.ts | 18 +++- .../src/components/Settings/SettingsModal.tsx | 14 +-- lama_cleaner/app/src/components/Workspace.tsx | 85 ++++++++++++++++++- .../app/src/components/shared/Toast.scss | 83 ++++++++++++++++++ .../app/src/components/shared/Toast.tsx | 81 ++++++++++++++++++ lama_cleaner/app/src/store/Atoms.tsx | 44 +++++++--- lama_cleaner/app/src/styles/_Animations.scss | 18 ++++ lama_cleaner/app/src/styles/_Colors.scss | 4 + lama_cleaner/app/src/styles/_ColorsDark.scss | 1 + lama_cleaner/app/src/styles/_index.scss | 1 + lama_cleaner/app/yarn.lock | 73 ++++++++++++++++ lama_cleaner/helper.py | 7 +- lama_cleaner/model/base.py | 5 ++ lama_cleaner/model/lama.py | 8 +- lama_cleaner/model/ldm.py | 12 ++- lama_cleaner/model_manager.py | 8 ++ main.py | 12 ++- 18 files changed, 447 insertions(+), 28 deletions(-) create mode 100644 lama_cleaner/app/src/components/shared/Toast.scss create mode 100644 lama_cleaner/app/src/components/shared/Toast.tsx diff --git a/lama_cleaner/app/package.json b/lama_cleaner/app/package.json index 13095a8..3784d4f 100644 --- a/lama_cleaner/app/package.json +++ b/lama_cleaner/app/package.json @@ -6,6 +6,7 @@ "dependencies": { "@heroicons/react": "^1.0.4", "@radix-ui/react-switch": "^0.1.5", + "@radix-ui/react-toast": "^0.1.1", "@testing-library/jest-dom": "^5.14.1", "@testing-library/react": "^12.1.2", "@testing-library/user-event": "^13.5.0", diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index d2d1028..cac4bb4 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -1,4 +1,4 @@ -import { Setting } from '../store/Atoms' +import { Settings } from '../store/Atoms' import { dataURItoBlob } from '../utils' export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}` @@ -6,7 +6,7 @@ export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}` export default async function inpaint( imageFile: File, maskBase64: string, - settings: Setting, + settings: Settings, sizeLimit?: string ) { // 1080, 2000, Original @@ -43,8 +43,20 @@ export default async function inpaint( export function switchModel(name: string) { const fd = new FormData() fd.append('name', name) - return fetch(`${API_ENDPOINT}/switch_model`, { + return fetch(`${API_ENDPOINT}/model`, { method: 'POST', body: fd, }) } + +export function currentModel() { + return fetch(`${API_ENDPOINT}/model`, { + method: 'GET', + }) +} + +export function modelDownloaded(name: string) { + return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, { + method: 'GET', + }) +} diff --git a/lama_cleaner/app/src/components/Settings/SettingsModal.tsx b/lama_cleaner/app/src/components/Settings/SettingsModal.tsx index c264e80..d5f9435 100644 --- a/lama_cleaner/app/src/components/Settings/SettingsModal.tsx +++ b/lama_cleaner/app/src/components/Settings/SettingsModal.tsx @@ -1,26 +1,28 @@ import React from 'react' import { useRecoilState } from 'recoil' -import { switchModel } from '../../adapters/inpainting' import { settingState } from '../../store/Atoms' import Modal from '../shared/Modal' import HDSettingBlock from './HDSettingBlock' import ModelSettingBlock from './ModelSettingBlock' -export default function SettingModal() { +interface SettingModalProps { + onClose: () => void +} +export default function SettingModal(props: SettingModalProps) { + const { onClose } = props const [setting, setSettingState] = useRecoilState(settingState) - const onClose = () => { + const handleOnClose = () => { setSettingState(old => { return { ...old, show: false } }) - - switchModel(setting.model) + onClose() } return ( { + const [settings, setSettingState] = useRecoilState(settingState) + const [toastVal, setToastState] = useRecoilState(toastState) + + const onSettingClose = async () => { + const curModel = await currentModel().then(res => res.text()) + if (curModel === settings.model) { + return + } + const downloaded = await modelDownloaded(settings.model).then(res => + res.text() + ) + + const { model } = settings + + let loadingMessage = `Switching to ${model} model` + let loadingDuration = 3000 + if (downloaded === 'False') { + loadingMessage = `Downloading ${model} model, this may take a while` + loadingDuration = 9999999999 + } + + setToastState({ + open: true, + desc: loadingMessage, + state: 'loading', + duration: loadingDuration, + }) + + switchModel(model) + .then(res => { + if (res.ok) { + setToastState({ + open: true, + desc: `Switch to ${model} model success`, + state: 'success', + duration: 3000, + }) + } else { + throw new Error('Server error') + } + }) + .catch(() => { + setToastState({ + open: true, + desc: `Switch to ${model} model failed`, + state: 'error', + duration: 3000, + }) + setSettingState(old => { + return { ...old, model: curModel as AIModel } + }) + }) + } + + useEffect(() => { + currentModel() + .then(res => res.text()) + .then(model => { + setSettingState(old => { + return { ...old, model: model as AIModel } + }) + }) + }, []) + return ( <> - + + { + setToastState(old => { + return { ...old, open } + }) + }} + /> ) } diff --git a/lama_cleaner/app/src/components/shared/Toast.scss b/lama_cleaner/app/src/components/shared/Toast.scss new file mode 100644 index 0000000..828f744 --- /dev/null +++ b/lama_cleaner/app/src/components/shared/Toast.scss @@ -0,0 +1,83 @@ +.toast-viewpoint { + position: fixed; + top: 48px; + right: 0; + display: flex; + flex-direction: row; + padding: 25px; + gap: 10px; + max-width: 100vw; + margin: 0; + z-index: 999999; + + &:focus-visible { + outline: none; + } +} + +.toast-root { + border: 1px solid var(--border-color-light); + background-color: var(--page-bg); + border-radius: 0.6rem; + padding: 15px; + display: flex; + align-items: center; + + gap: 12px; + + &[data-state='open'] { + animation: slideIn 150ms cubic-bezier(0.16, 1, 0.3, 1); + } + + &[data-state='close'] { + animation: opacityReveal 100ms ease-in forwards; + } + + &[data-state='cancel'] { + transform: translateX(0); + animation: transform 100ms ease-out; + } + + &.error { + border: 1px solid var(--error-color); + } + + &.success { + border: 1px solid var(--success-color); + } +} + +.error-icon { + height: 24px; + width: 24px; + color: var(--error-color); +} + +.success-icon { + height: 24px; + width: 24px; + color: var(--success-color); +} + +.loading-icon { + display: flex; + align-items: center; + animation-name: spin; + animation-duration: 1500ms; + animation-iteration-count: infinite; + transform-origin: center center; + animation-timing-function: linear; +} + +.toast-icon { + display: flex; + align-items: center; +} + +.toast-desc { + display: flex; + align-items: center; + margin: 0; + color: var(--text-color); + min-width: 240px; +} diff --git a/lama_cleaner/app/src/components/shared/Toast.tsx b/lama_cleaner/app/src/components/shared/Toast.tsx new file mode 100644 index 0000000..b1c9f0a --- /dev/null +++ b/lama_cleaner/app/src/components/shared/Toast.tsx @@ -0,0 +1,81 @@ +import * as React from 'react' +import * as ToastPrimitive from '@radix-ui/react-toast' +import { ToastProps } from '@radix-ui/react-toast' +import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/outline' + +const LoadingIcon = () => { + return ( + + + + + + + + + + + + + ) +} + +export type ToastState = 'default' | 'error' | 'loading' | 'success' + +interface MyToastProps extends ToastProps { + desc: string + state?: ToastState +} + +const Toast = React.forwardRef< + React.ElementRef, + MyToastProps +>((props, forwardedRef) => { + const { state, desc, ...itemProps } = props + + const getIcon = () => { + switch (state) { + case 'error': + return + case 'success': + return + case 'loading': + return + default: + return <> + } + } + + return ( + + +
{getIcon()}
+ + {desc} + +
+ +
+ ) +}) + +Toast.defaultProps = { + desc: '', + state: 'loading', +} + +export default Toast diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx index bad9e43..99181e4 100644 --- a/lama_cleaner/app/src/store/Atoms.tsx +++ b/lama_cleaner/app/src/store/Atoms.tsx @@ -1,18 +1,36 @@ import { atom } from 'recoil' import { HDStrategy } from '../components/Settings/HDSettingBlock' import { AIModel } from '../components/Settings/ModelSettingBlock' +import { ToastState } from '../components/shared/Toast' export const fileState = atom({ key: 'fileState', default: undefined, }) +interface ToastAtomState { + open: boolean + desc: string + state: ToastState + duration: number +} + +export const toastState = atom({ + key: 'toastState', + default: { + open: false, + desc: '', + state: 'default', + duration: 3000, + }, +}) + export const shortcutsState = atom({ key: 'shortcutsState', default: false, }) -export interface Setting { +export interface Settings { show: boolean saveImageBesideOrigin: boolean model: AIModel @@ -27,16 +45,18 @@ export interface Setting { ldmSteps: number } -export const settingState = atom({ +export const settingStateDefault = { + show: false, + saveImageBesideOrigin: false, + model: AIModel.LAMA, + ldmSteps: 50, + hdStrategy: HDStrategy.RESIZE, + hdStrategyResizeLimit: 2048, + hdStrategyCropTrigerSize: 2048, + hdStrategyCropMargin: 128, +} + +export const settingState = atom({ key: 'settingsState', - default: { - show: false, - saveImageBesideOrigin: false, - model: AIModel.LAMA, - hdStrategy: HDStrategy.RESIZE, - hdStrategyResizeLimit: 2048, - hdStrategyCropTrigerSize: 2048, - hdStrategyCropMargin: 128, - ldmSteps: 50, - }, + default: settingStateDefault, }) diff --git a/lama_cleaner/app/src/styles/_Animations.scss b/lama_cleaner/app/src/styles/_Animations.scss index 0f71aac..17a2a7e 100644 --- a/lama_cleaner/app/src/styles/_Animations.scss +++ b/lama_cleaner/app/src/styles/_Animations.scss @@ -37,3 +37,21 @@ transform: translateY(0); } } + +@keyframes slideIn { + 0% { + transform: translateX(calc(100% + 25px)); + } + 100% { + transform: translateX(0); + } +} + +@keyframes spin { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} diff --git a/lama_cleaner/app/src/styles/_Colors.scss b/lama_cleaner/app/src/styles/_Colors.scss index 468735a..5c229a9 100644 --- a/lama_cleaner/app/src/styles/_Colors.scss +++ b/lama_cleaner/app/src/styles/_Colors.scss @@ -7,6 +7,10 @@ --yellow-accent: #ffcc00; --link-color: rgb(0, 0, 0); --border-color: rgb(100, 100, 120); + --border-color-light: rgba(100, 100, 120, 0.5); + + --error-color: rgb(239, 68, 68); + --success-color: rgb(16, 185, 129); // Editor --editor-toolkit-bg: rgba(255, 255, 255, 0.5); diff --git a/lama_cleaner/app/src/styles/_ColorsDark.scss b/lama_cleaner/app/src/styles/_ColorsDark.scss index 6a48fc3..c450346 100644 --- a/lama_cleaner/app/src/styles/_ColorsDark.scss +++ b/lama_cleaner/app/src/styles/_ColorsDark.scss @@ -7,6 +7,7 @@ --yellow-accent: #ffcc00; --link-color: var(--yellow-accent); --border-color: rgb(100, 100, 120); + --border-color-light: rgba(102, 102, 102); // Editor --editor-toolkit-bg: rgba(0, 0, 0, 0.5); diff --git a/lama_cleaner/app/src/styles/_index.scss b/lama_cleaner/app/src/styles/_index.scss index d6e4210..e277d5f 100644 --- a/lama_cleaner/app/src/styles/_index.scss +++ b/lama_cleaner/app/src/styles/_index.scss @@ -20,6 +20,7 @@ @use '../components/shared/Selector'; @use '../components/shared/Switch'; @use '../components/shared/NumberInput'; +@use '../components/shared/Toast'; // Main CSS *, diff --git a/lama_cleaner/app/yarn.lock b/lama_cleaner/app/yarn.lock index 153b58d..c192280 100644 --- a/lama_cleaner/app/yarn.lock +++ b/lama_cleaner/app/yarn.lock @@ -1565,6 +1565,19 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-dismissable-layer@0.1.5": + version "0.1.5" + resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-0.1.5.tgz#9379032351e79028d472733a5cc8ba4a0ea43314" + integrity sha512-J+fYWijkX4M4QKwf9dtu1oC0U6e6CEl8WhBp3Ad23yz2Hia0XCo6Pk/mp5CAFy4QBtQedTSkhW05AdtSOEoajQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "0.1.0" + "@radix-ui/react-compose-refs" "0.1.0" + "@radix-ui/react-primitive" "0.1.4" + "@radix-ui/react-use-body-pointer-events" "0.1.1" + "@radix-ui/react-use-callback-ref" "0.1.0" + "@radix-ui/react-use-escape-keydown" "0.1.0" + "@radix-ui/react-id@0.1.5": version "0.1.5" resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8" @@ -1584,6 +1597,24 @@ "@radix-ui/react-id" "0.1.5" "@radix-ui/react-primitive" "0.1.4" +"@radix-ui/react-portal@0.1.4": + version "0.1.4" + resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2" + integrity sha512-MO0wRy2eYRTZ/CyOri9NANCAtAtq89DEtg90gicaTlkCfdqCLEBsLb+/q66BZQTr3xX/Vq01nnVfc/TkCqoqvw== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-primitive" "0.1.4" + "@radix-ui/react-use-layout-effect" "0.1.0" + +"@radix-ui/react-presence@0.1.2": + version "0.1.2" + resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3" + integrity sha512-3BRlFZraooIUfRlyN+b/Xs5hq1lanOOo/+3h6Pwu2GMFjkGKKa4Rd51fcqGqnVlbr3jYg+WLuGyAV4KlgqwrQw== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-compose-refs" "0.1.0" + "@radix-ui/react-use-layout-effect" "0.1.0" + "@radix-ui/react-primitive@0.1.4": version "0.1.4" resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1" @@ -1615,6 +1646,32 @@ "@radix-ui/react-use-previous" "0.1.1" "@radix-ui/react-use-size" "0.1.1" +"@radix-ui/react-toast@^0.1.1": + version "0.1.1" + resolved "https://registry.npmmirror.com/@radix-ui/react-toast/-/react-toast-0.1.1.tgz#d544e796b307e56f1298e40f356f468680958e93" + integrity sha512-9JWC4mPP78OE6muDrpaPf/71dIeozppdcnik1IvsjTxZpDnt9PbTtQj94DdWjlCphbv3S5faD3KL0GOpqKBpTQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "0.1.0" + "@radix-ui/react-compose-refs" "0.1.0" + "@radix-ui/react-context" "0.1.1" + "@radix-ui/react-dismissable-layer" "0.1.5" + "@radix-ui/react-portal" "0.1.4" + "@radix-ui/react-presence" "0.1.2" + "@radix-ui/react-primitive" "0.1.4" + "@radix-ui/react-use-callback-ref" "0.1.0" + "@radix-ui/react-use-controllable-state" "0.1.0" + "@radix-ui/react-use-layout-effect" "0.1.0" + "@radix-ui/react-visually-hidden" "0.1.4" + +"@radix-ui/react-use-body-pointer-events@0.1.1": + version "0.1.1" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-body-pointer-events/-/react-use-body-pointer-events-0.1.1.tgz#63e7fd81ca7ffd30841deb584cd2b7f460df2597" + integrity sha512-R8leV2AWmJokTmERM8cMXFHWSiv/fzOLhG/JLmRBhLTAzOj37EQizssq4oW0Z29VcZy2tODMi9Pk/htxwb+xpA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-layout-effect" "0.1.0" + "@radix-ui/react-use-callback-ref@0.1.0": version "0.1.0" resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-0.1.0.tgz#934b6e123330f5b3a6b116460e6662cbc663493f" @@ -1630,6 +1687,14 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-use-callback-ref" "0.1.0" +"@radix-ui/react-use-escape-keydown@0.1.0": + version "0.1.0" + resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874" + integrity sha512-tDLZbTGFmvXaazUXXv8kYbiCcbAE8yKgng9s95d8fCO+Eundv0Jngbn/hKPhDDs4jj9ChwRX5cDDnlaN+ugYYQ== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-use-callback-ref" "0.1.0" + "@radix-ui/react-use-layout-effect@0.1.0": version "0.1.0" resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223" @@ -1651,6 +1716,14 @@ dependencies: "@babel/runtime" "^7.13.10" +"@radix-ui/react-visually-hidden@0.1.4": + version "0.1.4" + resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03" + integrity sha512-K/q6AEEzqeeEq/T0NPChvBqnwlp8Tl4NnQdrI/y8IOY7BRR+Ug0PEsVk6g48HJ7cA1//COugdxXXVVK/m0X1mA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-primitive" "0.1.4" + "@rollup/plugin-node-resolve@^7.1.1": version "7.1.3" resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz" diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 2b7768b..04e3b77 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -9,7 +9,7 @@ import torch from torch.hub import download_url_to_file, get_dir -def download_model(url): +def get_cache_path_by_url(url): parts = urlparse(url) hub_dir = get_dir() model_dir = os.path.join(hub_dir, "checkpoints") @@ -17,6 +17,11 @@ def download_model(url): os.makedirs(os.path.join(model_dir, "hub", "checkpoints")) filename = os.path.basename(parts.path) cached_file = os.path.join(model_dir, filename) + return cached_file + + +def download_model(url): + cached_file = get_cache_path_by_url(url) if not os.path.exists(cached_file): sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) hash_prefix = None diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 31059dc..21b4329 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -24,6 +24,11 @@ class InpaintModel: def init_model(self, device): ... + @staticmethod + @abc.abstractmethod + def is_downloaded() -> bool: + ... + @abc.abstractmethod def forward(self, image, mask, config: Config): """Input image and output image have same size diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index 20ea2f8..eba3860 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -5,7 +5,7 @@ import numpy as np import torch from loguru import logger -from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img +from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config @@ -36,12 +36,16 @@ class LaMa(InpaintModel): ) else: model_path = download_model(LAMA_MODEL_URL) - logger.info(f"Load LaMa model from: {model_path}") model = torch.jit.load(model_path, map_location="cpu") model = model.to(device) model.eval() self.model = model + self.model_path = model_path + + @staticmethod + def is_downloaded() -> bool: + return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL)) def forward(self, image, mask, config: Config): """Input image and output image have same size diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index 0b433b9..fc7a0e0 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -10,7 +10,7 @@ from lama_cleaner.schema import Config torch.manual_seed(42) import torch.nn as nn from tqdm import tqdm -from lama_cleaner.helper import download_model, norm_img +from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \ timestep_embedding @@ -266,6 +266,7 @@ class DDIMSampler(object): def load_jit_model(url, device): model_path = download_model(url) + logger.info(f"Load LDM model from: {model_path}") model = torch.jit.load(model_path).to(device) model.eval() return model @@ -286,6 +287,15 @@ class LDM(InpaintModel): model = LatentDiffusion(self.diffusion_model, device) self.sampler = DDIMSampler(model) + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL), + get_cache_path_by_url(LDM_DECODE_MODEL_URL), + get_cache_path_by_url(LDM_ENCODE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + def forward(self, image, mask, config: Config): """ image: [H, W, C] RGB diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 105561e..933e734 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -21,6 +21,14 @@ class ModelManager: raise NotImplementedError(f"Not supported model: {name}") return model + def is_downloaded(self, name: str) -> bool: + if name == self.LAMA: + return LaMa.is_downloaded() + elif name == self.LDM: + return LDM.is_downloaded() + else: + raise NotImplementedError(f"Not supported model: {name}") + def __call__(self, image, mask, config: Config): return self.model(image, mask, config) diff --git a/main.py b/main.py index 0c917d6..708238b 100644 --- a/main.py +++ b/main.py @@ -125,7 +125,17 @@ def process(): ) -@app.route("/switch_model", methods=["POST"]) +@app.route("/model") +def current_model(): + return model.name, 200 + + +@app.route("/model_downloaded/") +def model_downloaded(name): + return str(model.is_downloaded(name)), 200 + + +@app.route("/model", methods=["POST"]) def switch_model(): new_name = request.form.get("name") if new_name == model.name: