make model switch work with toast

This commit is contained in:
Sanster 2022-04-17 23:31:12 +08:00
parent 205286a414
commit f7e1e073dc
18 changed files with 447 additions and 28 deletions

View File

@ -6,6 +6,7 @@
"dependencies": {
"@heroicons/react": "^1.0.4",
"@radix-ui/react-switch": "^0.1.5",
"@radix-ui/react-toast": "^0.1.1",
"@testing-library/jest-dom": "^5.14.1",
"@testing-library/react": "^12.1.2",
"@testing-library/user-event": "^13.5.0",

View File

@ -1,4 +1,4 @@
import { Setting } from '../store/Atoms'
import { Settings } from '../store/Atoms'
import { dataURItoBlob } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
@ -6,7 +6,7 @@ export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
export default async function inpaint(
imageFile: File,
maskBase64: string,
settings: Setting,
settings: Settings,
sizeLimit?: string
) {
// 1080, 2000, Original
@ -43,8 +43,20 @@ export default async function inpaint(
export function switchModel(name: string) {
const fd = new FormData()
fd.append('name', name)
return fetch(`${API_ENDPOINT}/switch_model`, {
return fetch(`${API_ENDPOINT}/model`, {
method: 'POST',
body: fd,
})
}
export function currentModel() {
return fetch(`${API_ENDPOINT}/model`, {
method: 'GET',
})
}
export function modelDownloaded(name: string) {
return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, {
method: 'GET',
})
}

View File

@ -1,26 +1,28 @@
import React from 'react'
import { useRecoilState } from 'recoil'
import { switchModel } from '../../adapters/inpainting'
import { settingState } from '../../store/Atoms'
import Modal from '../shared/Modal'
import HDSettingBlock from './HDSettingBlock'
import ModelSettingBlock from './ModelSettingBlock'
export default function SettingModal() {
interface SettingModalProps {
onClose: () => void
}
export default function SettingModal(props: SettingModalProps) {
const { onClose } = props
const [setting, setSettingState] = useRecoilState(settingState)
const onClose = () => {
const handleOnClose = () => {
setSettingState(old => {
return { ...old, show: false }
})
switchModel(setting.model)
onClose()
}
return (
<Modal
onClose={onClose}
onClose={handleOnClose}
title="Settings"
className="modal-setting"
show={setting.show}

View File

@ -1,18 +1,99 @@
import React from 'react'
import React, { useEffect } from 'react'
import { useRecoilState } from 'recoil'
import Editor from './Editor/Editor'
import ShortcutsModal from './Shortcuts/ShortcutsModal'
import SettingModal from './Settings/SettingsModal'
import Toast from './shared/Toast'
import { Settings, settingState, toastState } from '../store/Atoms'
import {
currentModel,
modelDownloaded,
switchModel,
} from '../adapters/inpainting'
import { AIModel } from './Settings/ModelSettingBlock'
interface WorkspaceProps {
file: File
}
const Workspace = ({ file }: WorkspaceProps) => {
const [settings, setSettingState] = useRecoilState(settingState)
const [toastVal, setToastState] = useRecoilState(toastState)
const onSettingClose = async () => {
const curModel = await currentModel().then(res => res.text())
if (curModel === settings.model) {
return
}
const downloaded = await modelDownloaded(settings.model).then(res =>
res.text()
)
const { model } = settings
let loadingMessage = `Switching to ${model} model`
let loadingDuration = 3000
if (downloaded === 'False') {
loadingMessage = `Downloading ${model} model, this may take a while`
loadingDuration = 9999999999
}
setToastState({
open: true,
desc: loadingMessage,
state: 'loading',
duration: loadingDuration,
})
switchModel(model)
.then(res => {
if (res.ok) {
setToastState({
open: true,
desc: `Switch to ${model} model success`,
state: 'success',
duration: 3000,
})
} else {
throw new Error('Server error')
}
})
.catch(() => {
setToastState({
open: true,
desc: `Switch to ${model} model failed`,
state: 'error',
duration: 3000,
})
setSettingState(old => {
return { ...old, model: curModel as AIModel }
})
})
}
useEffect(() => {
currentModel()
.then(res => res.text())
.then(model => {
setSettingState(old => {
return { ...old, model: model as AIModel }
})
})
}, [])
return (
<>
<Editor file={file} />
<SettingModal />
<SettingModal onClose={onSettingClose} />
<ShortcutsModal />
<Toast
{...toastVal}
onOpenChange={(open: boolean) => {
setToastState(old => {
return { ...old, open }
})
}}
/>
</>
)
}

View File

@ -0,0 +1,83 @@
.toast-viewpoint {
position: fixed;
top: 48px;
right: 0;
display: flex;
flex-direction: row;
padding: 25px;
gap: 10px;
max-width: 100vw;
margin: 0;
z-index: 999999;
&:focus-visible {
outline: none;
}
}
.toast-root {
border: 1px solid var(--border-color-light);
background-color: var(--page-bg);
border-radius: 0.6rem;
padding: 15px;
display: flex;
align-items: center;
gap: 12px;
&[data-state='open'] {
animation: slideIn 150ms cubic-bezier(0.16, 1, 0.3, 1);
}
&[data-state='close'] {
animation: opacityReveal 100ms ease-in forwards;
}
&[data-state='cancel'] {
transform: translateX(0);
animation: transform 100ms ease-out;
}
&.error {
border: 1px solid var(--error-color);
}
&.success {
border: 1px solid var(--success-color);
}
}
.error-icon {
height: 24px;
width: 24px;
color: var(--error-color);
}
.success-icon {
height: 24px;
width: 24px;
color: var(--success-color);
}
.loading-icon {
display: flex;
align-items: center;
animation-name: spin;
animation-duration: 1500ms;
animation-iteration-count: infinite;
transform-origin: center center;
animation-timing-function: linear;
}
.toast-icon {
display: flex;
align-items: center;
}
.toast-desc {
display: flex;
align-items: center;
margin: 0;
color: var(--text-color);
min-width: 240px;
}

View File

@ -0,0 +1,81 @@
import * as React from 'react'
import * as ToastPrimitive from '@radix-ui/react-toast'
import { ToastProps } from '@radix-ui/react-toast'
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/outline'
const LoadingIcon = () => {
return (
<span className="loading-icon">
<svg
xmlns="http://www.w3.org/2000/svg"
width="20"
height="20"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
>
<line x1="12" y1="2" x2="12" y2="6" />
<line x1="12" y1="18" x2="12" y2="22" />
<line x1="4.93" y1="4.93" x2="7.76" y2="7.76" />
<line x1="16.24" y1="16.24" x2="19.07" y2="19.07" />
<line x1="2" y1="12" x2="6" y2="12" />
<line x1="18" y1="12" x2="22" y2="12" />
<line x1="4.93" y1="19.07" x2="7.76" y2="16.24" />
<line x1="16.24" y1="7.76" x2="19.07" y2="4.93" />
</svg>
</span>
)
}
export type ToastState = 'default' | 'error' | 'loading' | 'success'
interface MyToastProps extends ToastProps {
desc: string
state?: ToastState
}
const Toast = React.forwardRef<
React.ElementRef<typeof ToastPrimitive.Root>,
MyToastProps
>((props, forwardedRef) => {
const { state, desc, ...itemProps } = props
const getIcon = () => {
switch (state) {
case 'error':
return <ExclamationCircleIcon className="error-icon" />
case 'success':
return <CheckIcon className="success-icon" />
case 'loading':
return <LoadingIcon />
default:
return <></>
}
}
return (
<ToastPrimitive.Provider>
<ToastPrimitive.Root
{...itemProps}
ref={forwardedRef}
className={`toast-root ${state}`}
>
<div className="toast-icon">{getIcon()}</div>
<ToastPrimitive.Description className="toast-desc">
{desc}
</ToastPrimitive.Description>
</ToastPrimitive.Root>
<ToastPrimitive.Viewport className="toast-viewpoint" />
</ToastPrimitive.Provider>
)
})
Toast.defaultProps = {
desc: '',
state: 'loading',
}
export default Toast

View File

@ -1,18 +1,36 @@
import { atom } from 'recoil'
import { HDStrategy } from '../components/Settings/HDSettingBlock'
import { AIModel } from '../components/Settings/ModelSettingBlock'
import { ToastState } from '../components/shared/Toast'
export const fileState = atom<File | undefined>({
key: 'fileState',
default: undefined,
})
interface ToastAtomState {
open: boolean
desc: string
state: ToastState
duration: number
}
export const toastState = atom<ToastAtomState>({
key: 'toastState',
default: {
open: false,
desc: '',
state: 'default',
duration: 3000,
},
})
export const shortcutsState = atom<boolean>({
key: 'shortcutsState',
default: false,
})
export interface Setting {
export interface Settings {
show: boolean
saveImageBesideOrigin: boolean
model: AIModel
@ -27,16 +45,18 @@ export interface Setting {
ldmSteps: number
}
export const settingState = atom<Setting>({
export const settingStateDefault = {
show: false,
saveImageBesideOrigin: false,
model: AIModel.LAMA,
ldmSteps: 50,
hdStrategy: HDStrategy.RESIZE,
hdStrategyResizeLimit: 2048,
hdStrategyCropTrigerSize: 2048,
hdStrategyCropMargin: 128,
}
export const settingState = atom<Settings>({
key: 'settingsState',
default: {
show: false,
saveImageBesideOrigin: false,
model: AIModel.LAMA,
hdStrategy: HDStrategy.RESIZE,
hdStrategyResizeLimit: 2048,
hdStrategyCropTrigerSize: 2048,
hdStrategyCropMargin: 128,
ldmSteps: 50,
},
default: settingStateDefault,
})

View File

@ -37,3 +37,21 @@
transform: translateY(0);
}
}
@keyframes slideIn {
0% {
transform: translateX(calc(100% + 25px));
}
100% {
transform: translateX(0);
}
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

View File

@ -7,6 +7,10 @@
--yellow-accent: #ffcc00;
--link-color: rgb(0, 0, 0);
--border-color: rgb(100, 100, 120);
--border-color-light: rgba(100, 100, 120, 0.5);
--error-color: rgb(239, 68, 68);
--success-color: rgb(16, 185, 129);
// Editor
--editor-toolkit-bg: rgba(255, 255, 255, 0.5);

View File

@ -7,6 +7,7 @@
--yellow-accent: #ffcc00;
--link-color: var(--yellow-accent);
--border-color: rgb(100, 100, 120);
--border-color-light: rgba(102, 102, 102);
// Editor
--editor-toolkit-bg: rgba(0, 0, 0, 0.5);

View File

@ -20,6 +20,7 @@
@use '../components/shared/Selector';
@use '../components/shared/Switch';
@use '../components/shared/NumberInput';
@use '../components/shared/Toast';
// Main CSS
*,

View File

@ -1565,6 +1565,19 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-dismissable-layer@0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-0.1.5.tgz#9379032351e79028d472733a5cc8ba4a0ea43314"
integrity sha512-J+fYWijkX4M4QKwf9dtu1oC0U6e6CEl8WhBp3Ad23yz2Hia0XCo6Pk/mp5CAFy4QBtQedTSkhW05AdtSOEoajQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "0.1.0"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-use-body-pointer-events" "0.1.1"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-escape-keydown" "0.1.0"
"@radix-ui/react-id@0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8"
@ -1584,6 +1597,24 @@
"@radix-ui/react-id" "0.1.5"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-portal@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2"
integrity sha512-MO0wRy2eYRTZ/CyOri9NANCAtAtq89DEtg90gicaTlkCfdqCLEBsLb+/q66BZQTr3xX/Vq01nnVfc/TkCqoqvw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-presence@0.1.2":
version "0.1.2"
resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3"
integrity sha512-3BRlFZraooIUfRlyN+b/Xs5hq1lanOOo/+3h6Pwu2GMFjkGKKa4Rd51fcqGqnVlbr3jYg+WLuGyAV4KlgqwrQw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-primitive@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1"
@ -1615,6 +1646,32 @@
"@radix-ui/react-use-previous" "0.1.1"
"@radix-ui/react-use-size" "0.1.1"
"@radix-ui/react-toast@^0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-toast/-/react-toast-0.1.1.tgz#d544e796b307e56f1298e40f356f468680958e93"
integrity sha512-9JWC4mPP78OE6muDrpaPf/71dIeozppdcnik1IvsjTxZpDnt9PbTtQj94DdWjlCphbv3S5faD3KL0GOpqKBpTQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "0.1.0"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-context" "0.1.1"
"@radix-ui/react-dismissable-layer" "0.1.5"
"@radix-ui/react-portal" "0.1.4"
"@radix-ui/react-presence" "0.1.2"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-controllable-state" "0.1.0"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-visually-hidden" "0.1.4"
"@radix-ui/react-use-body-pointer-events@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-body-pointer-events/-/react-use-body-pointer-events-0.1.1.tgz#63e7fd81ca7ffd30841deb584cd2b7f460df2597"
integrity sha512-R8leV2AWmJokTmERM8cMXFHWSiv/fzOLhG/JLmRBhLTAzOj37EQizssq4oW0Z29VcZy2tODMi9Pk/htxwb+xpA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-use-callback-ref@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-0.1.0.tgz#934b6e123330f5b3a6b116460e6662cbc663493f"
@ -1630,6 +1687,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-escape-keydown@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874"
integrity sha512-tDLZbTGFmvXaazUXXv8kYbiCcbAE8yKgng9s95d8fCO+Eundv0Jngbn/hKPhDDs4jj9ChwRX5cDDnlaN+ugYYQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-layout-effect@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223"
@ -1651,6 +1716,14 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-visually-hidden@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03"
integrity sha512-K/q6AEEzqeeEq/T0NPChvBqnwlp8Tl4NnQdrI/y8IOY7BRR+Ug0PEsVk6g48HJ7cA1//COugdxXXVVK/m0X1mA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "0.1.4"
"@rollup/plugin-node-resolve@^7.1.1":
version "7.1.3"
resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz"

View File

@ -9,7 +9,7 @@ import torch
from torch.hub import download_url_to_file, get_dir
def download_model(url):
def get_cache_path_by_url(url):
parts = urlparse(url)
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
@ -17,6 +17,11 @@ def download_model(url):
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
return cached_file
def download_model(url):
cached_file = get_cache_path_by_url(url)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None

View File

@ -24,6 +24,11 @@ class InpaintModel:
def init_model(self, device):
...
@staticmethod
@abc.abstractmethod
def is_downloaded() -> bool:
...
@abc.abstractmethod
def forward(self, image, mask, config: Config):
"""Input image and output image have same size

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from loguru import logger
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
@ -36,12 +36,16 @@ class LaMa(InpaintModel):
)
else:
model_path = download_model(LAMA_MODEL_URL)
logger.info(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
self.model_path = model_path
@staticmethod
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
def forward(self, image, mask, config: Config):
"""Input image and output image have same size

View File

@ -10,7 +10,7 @@ from lama_cleaner.schema import Config
torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
from lama_cleaner.helper import download_model, norm_img
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
timestep_embedding
@ -266,6 +266,7 @@ class DDIMSampler(object):
def load_jit_model(url, device):
model_path = download_model(url)
logger.info(f"Load LDM model from: {model_path}")
model = torch.jit.load(model_path).to(device)
model.eval()
return model
@ -286,6 +287,15 @@ class LDM(InpaintModel):
model = LatentDiffusion(self.diffusion_model, device)
self.sampler = DDIMSampler(model)
@staticmethod
def is_downloaded() -> bool:
model_paths = [
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
]
return all([os.path.exists(it) for it in model_paths])
def forward(self, image, mask, config: Config):
"""
image: [H, W, C] RGB

View File

@ -21,6 +21,14 @@ class ModelManager:
raise NotImplementedError(f"Not supported model: {name}")
return model
def is_downloaded(self, name: str) -> bool:
if name == self.LAMA:
return LaMa.is_downloaded()
elif name == self.LDM:
return LDM.is_downloaded()
else:
raise NotImplementedError(f"Not supported model: {name}")
def __call__(self, image, mask, config: Config):
return self.model(image, mask, config)

12
main.py
View File

@ -125,7 +125,17 @@ def process():
)
@app.route("/switch_model", methods=["POST"])
@app.route("/model")
def current_model():
return model.name, 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
return str(model.is_downloaded(name)), 200
@app.route("/model", methods=["POST"])
def switch_model():
new_name = request.form.get("name")
if new_name == model.name: