make model switch work with toast
This commit is contained in:
parent
205286a414
commit
f7e1e073dc
@ -6,6 +6,7 @@
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^1.0.4",
|
||||
"@radix-ui/react-switch": "^0.1.5",
|
||||
"@radix-ui/react-toast": "^0.1.1",
|
||||
"@testing-library/jest-dom": "^5.14.1",
|
||||
"@testing-library/react": "^12.1.2",
|
||||
"@testing-library/user-event": "^13.5.0",
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Setting } from '../store/Atoms'
|
||||
import { Settings } from '../store/Atoms'
|
||||
import { dataURItoBlob } from '../utils'
|
||||
|
||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||
@ -6,7 +6,7 @@ export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||
export default async function inpaint(
|
||||
imageFile: File,
|
||||
maskBase64: string,
|
||||
settings: Setting,
|
||||
settings: Settings,
|
||||
sizeLimit?: string
|
||||
) {
|
||||
// 1080, 2000, Original
|
||||
@ -43,8 +43,20 @@ export default async function inpaint(
|
||||
export function switchModel(name: string) {
|
||||
const fd = new FormData()
|
||||
fd.append('name', name)
|
||||
return fetch(`${API_ENDPOINT}/switch_model`, {
|
||||
return fetch(`${API_ENDPOINT}/model`, {
|
||||
method: 'POST',
|
||||
body: fd,
|
||||
})
|
||||
}
|
||||
|
||||
export function currentModel() {
|
||||
return fetch(`${API_ENDPOINT}/model`, {
|
||||
method: 'GET',
|
||||
})
|
||||
}
|
||||
|
||||
export function modelDownloaded(name: string) {
|
||||
return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, {
|
||||
method: 'GET',
|
||||
})
|
||||
}
|
||||
|
@ -1,26 +1,28 @@
|
||||
import React from 'react'
|
||||
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { switchModel } from '../../adapters/inpainting'
|
||||
import { settingState } from '../../store/Atoms'
|
||||
import Modal from '../shared/Modal'
|
||||
import HDSettingBlock from './HDSettingBlock'
|
||||
import ModelSettingBlock from './ModelSettingBlock'
|
||||
|
||||
export default function SettingModal() {
|
||||
interface SettingModalProps {
|
||||
onClose: () => void
|
||||
}
|
||||
export default function SettingModal(props: SettingModalProps) {
|
||||
const { onClose } = props
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
|
||||
const onClose = () => {
|
||||
const handleOnClose = () => {
|
||||
setSettingState(old => {
|
||||
return { ...old, show: false }
|
||||
})
|
||||
|
||||
switchModel(setting.model)
|
||||
onClose()
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal
|
||||
onClose={onClose}
|
||||
onClose={handleOnClose}
|
||||
title="Settings"
|
||||
className="modal-setting"
|
||||
show={setting.show}
|
||||
|
@ -1,18 +1,99 @@
|
||||
import React from 'react'
|
||||
import React, { useEffect } from 'react'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import Editor from './Editor/Editor'
|
||||
import ShortcutsModal from './Shortcuts/ShortcutsModal'
|
||||
import SettingModal from './Settings/SettingsModal'
|
||||
import Toast from './shared/Toast'
|
||||
import { Settings, settingState, toastState } from '../store/Atoms'
|
||||
import {
|
||||
currentModel,
|
||||
modelDownloaded,
|
||||
switchModel,
|
||||
} from '../adapters/inpainting'
|
||||
import { AIModel } from './Settings/ModelSettingBlock'
|
||||
|
||||
interface WorkspaceProps {
|
||||
file: File
|
||||
}
|
||||
|
||||
const Workspace = ({ file }: WorkspaceProps) => {
|
||||
const [settings, setSettingState] = useRecoilState(settingState)
|
||||
const [toastVal, setToastState] = useRecoilState(toastState)
|
||||
|
||||
const onSettingClose = async () => {
|
||||
const curModel = await currentModel().then(res => res.text())
|
||||
if (curModel === settings.model) {
|
||||
return
|
||||
}
|
||||
const downloaded = await modelDownloaded(settings.model).then(res =>
|
||||
res.text()
|
||||
)
|
||||
|
||||
const { model } = settings
|
||||
|
||||
let loadingMessage = `Switching to ${model} model`
|
||||
let loadingDuration = 3000
|
||||
if (downloaded === 'False') {
|
||||
loadingMessage = `Downloading ${model} model, this may take a while`
|
||||
loadingDuration = 9999999999
|
||||
}
|
||||
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: loadingMessage,
|
||||
state: 'loading',
|
||||
duration: loadingDuration,
|
||||
})
|
||||
|
||||
switchModel(model)
|
||||
.then(res => {
|
||||
if (res.ok) {
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: `Switch to ${model} model success`,
|
||||
state: 'success',
|
||||
duration: 3000,
|
||||
})
|
||||
} else {
|
||||
throw new Error('Server error')
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: `Switch to ${model} model failed`,
|
||||
state: 'error',
|
||||
duration: 3000,
|
||||
})
|
||||
setSettingState(old => {
|
||||
return { ...old, model: curModel as AIModel }
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
currentModel()
|
||||
.then(res => res.text())
|
||||
.then(model => {
|
||||
setSettingState(old => {
|
||||
return { ...old, model: model as AIModel }
|
||||
})
|
||||
})
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<>
|
||||
<Editor file={file} />
|
||||
<SettingModal />
|
||||
<SettingModal onClose={onSettingClose} />
|
||||
<ShortcutsModal />
|
||||
<Toast
|
||||
{...toastVal}
|
||||
onOpenChange={(open: boolean) => {
|
||||
setToastState(old => {
|
||||
return { ...old, open }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
83
lama_cleaner/app/src/components/shared/Toast.scss
Normal file
83
lama_cleaner/app/src/components/shared/Toast.scss
Normal file
@ -0,0 +1,83 @@
|
||||
.toast-viewpoint {
|
||||
position: fixed;
|
||||
top: 48px;
|
||||
right: 0;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
padding: 25px;
|
||||
gap: 10px;
|
||||
max-width: 100vw;
|
||||
margin: 0;
|
||||
z-index: 999999;
|
||||
|
||||
&:focus-visible {
|
||||
outline: none;
|
||||
}
|
||||
}
|
||||
|
||||
.toast-root {
|
||||
border: 1px solid var(--border-color-light);
|
||||
background-color: var(--page-bg);
|
||||
border-radius: 0.6rem;
|
||||
padding: 15px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
||||
gap: 12px;
|
||||
|
||||
&[data-state='open'] {
|
||||
animation: slideIn 150ms cubic-bezier(0.16, 1, 0.3, 1);
|
||||
}
|
||||
|
||||
&[data-state='close'] {
|
||||
animation: opacityReveal 100ms ease-in forwards;
|
||||
}
|
||||
|
||||
&[data-state='cancel'] {
|
||||
transform: translateX(0);
|
||||
animation: transform 100ms ease-out;
|
||||
}
|
||||
|
||||
&.error {
|
||||
border: 1px solid var(--error-color);
|
||||
}
|
||||
|
||||
&.success {
|
||||
border: 1px solid var(--success-color);
|
||||
}
|
||||
}
|
||||
|
||||
.error-icon {
|
||||
height: 24px;
|
||||
width: 24px;
|
||||
color: var(--error-color);
|
||||
}
|
||||
|
||||
.success-icon {
|
||||
height: 24px;
|
||||
width: 24px;
|
||||
color: var(--success-color);
|
||||
}
|
||||
|
||||
.loading-icon {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
animation-name: spin;
|
||||
animation-duration: 1500ms;
|
||||
animation-iteration-count: infinite;
|
||||
transform-origin: center center;
|
||||
animation-timing-function: linear;
|
||||
}
|
||||
|
||||
.toast-icon {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.toast-desc {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 0;
|
||||
color: var(--text-color);
|
||||
min-width: 240px;
|
||||
}
|
81
lama_cleaner/app/src/components/shared/Toast.tsx
Normal file
81
lama_cleaner/app/src/components/shared/Toast.tsx
Normal file
@ -0,0 +1,81 @@
|
||||
import * as React from 'react'
|
||||
import * as ToastPrimitive from '@radix-ui/react-toast'
|
||||
import { ToastProps } from '@radix-ui/react-toast'
|
||||
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/outline'
|
||||
|
||||
const LoadingIcon = () => {
|
||||
return (
|
||||
<span className="loading-icon">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="20"
|
||||
height="20"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<line x1="12" y1="2" x2="12" y2="6" />
|
||||
<line x1="12" y1="18" x2="12" y2="22" />
|
||||
<line x1="4.93" y1="4.93" x2="7.76" y2="7.76" />
|
||||
<line x1="16.24" y1="16.24" x2="19.07" y2="19.07" />
|
||||
<line x1="2" y1="12" x2="6" y2="12" />
|
||||
<line x1="18" y1="12" x2="22" y2="12" />
|
||||
<line x1="4.93" y1="19.07" x2="7.76" y2="16.24" />
|
||||
<line x1="16.24" y1="7.76" x2="19.07" y2="4.93" />
|
||||
</svg>
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
export type ToastState = 'default' | 'error' | 'loading' | 'success'
|
||||
|
||||
interface MyToastProps extends ToastProps {
|
||||
desc: string
|
||||
state?: ToastState
|
||||
}
|
||||
|
||||
const Toast = React.forwardRef<
|
||||
React.ElementRef<typeof ToastPrimitive.Root>,
|
||||
MyToastProps
|
||||
>((props, forwardedRef) => {
|
||||
const { state, desc, ...itemProps } = props
|
||||
|
||||
const getIcon = () => {
|
||||
switch (state) {
|
||||
case 'error':
|
||||
return <ExclamationCircleIcon className="error-icon" />
|
||||
case 'success':
|
||||
return <CheckIcon className="success-icon" />
|
||||
case 'loading':
|
||||
return <LoadingIcon />
|
||||
default:
|
||||
return <></>
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<ToastPrimitive.Provider>
|
||||
<ToastPrimitive.Root
|
||||
{...itemProps}
|
||||
ref={forwardedRef}
|
||||
className={`toast-root ${state}`}
|
||||
>
|
||||
<div className="toast-icon">{getIcon()}</div>
|
||||
<ToastPrimitive.Description className="toast-desc">
|
||||
{desc}
|
||||
</ToastPrimitive.Description>
|
||||
</ToastPrimitive.Root>
|
||||
<ToastPrimitive.Viewport className="toast-viewpoint" />
|
||||
</ToastPrimitive.Provider>
|
||||
)
|
||||
})
|
||||
|
||||
Toast.defaultProps = {
|
||||
desc: '',
|
||||
state: 'loading',
|
||||
}
|
||||
|
||||
export default Toast
|
@ -1,18 +1,36 @@
|
||||
import { atom } from 'recoil'
|
||||
import { HDStrategy } from '../components/Settings/HDSettingBlock'
|
||||
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
||||
import { ToastState } from '../components/shared/Toast'
|
||||
|
||||
export const fileState = atom<File | undefined>({
|
||||
key: 'fileState',
|
||||
default: undefined,
|
||||
})
|
||||
|
||||
interface ToastAtomState {
|
||||
open: boolean
|
||||
desc: string
|
||||
state: ToastState
|
||||
duration: number
|
||||
}
|
||||
|
||||
export const toastState = atom<ToastAtomState>({
|
||||
key: 'toastState',
|
||||
default: {
|
||||
open: false,
|
||||
desc: '',
|
||||
state: 'default',
|
||||
duration: 3000,
|
||||
},
|
||||
})
|
||||
|
||||
export const shortcutsState = atom<boolean>({
|
||||
key: 'shortcutsState',
|
||||
default: false,
|
||||
})
|
||||
|
||||
export interface Setting {
|
||||
export interface Settings {
|
||||
show: boolean
|
||||
saveImageBesideOrigin: boolean
|
||||
model: AIModel
|
||||
@ -27,16 +45,18 @@ export interface Setting {
|
||||
ldmSteps: number
|
||||
}
|
||||
|
||||
export const settingState = atom<Setting>({
|
||||
export const settingStateDefault = {
|
||||
show: false,
|
||||
saveImageBesideOrigin: false,
|
||||
model: AIModel.LAMA,
|
||||
ldmSteps: 50,
|
||||
hdStrategy: HDStrategy.RESIZE,
|
||||
hdStrategyResizeLimit: 2048,
|
||||
hdStrategyCropTrigerSize: 2048,
|
||||
hdStrategyCropMargin: 128,
|
||||
}
|
||||
|
||||
export const settingState = atom<Settings>({
|
||||
key: 'settingsState',
|
||||
default: {
|
||||
show: false,
|
||||
saveImageBesideOrigin: false,
|
||||
model: AIModel.LAMA,
|
||||
hdStrategy: HDStrategy.RESIZE,
|
||||
hdStrategyResizeLimit: 2048,
|
||||
hdStrategyCropTrigerSize: 2048,
|
||||
hdStrategyCropMargin: 128,
|
||||
ldmSteps: 50,
|
||||
},
|
||||
default: settingStateDefault,
|
||||
})
|
||||
|
@ -37,3 +37,21 @@
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes slideIn {
|
||||
0% {
|
||||
transform: translateX(calc(100% + 25px));
|
||||
}
|
||||
100% {
|
||||
transform: translateX(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,10 @@
|
||||
--yellow-accent: #ffcc00;
|
||||
--link-color: rgb(0, 0, 0);
|
||||
--border-color: rgb(100, 100, 120);
|
||||
--border-color-light: rgba(100, 100, 120, 0.5);
|
||||
|
||||
--error-color: rgb(239, 68, 68);
|
||||
--success-color: rgb(16, 185, 129);
|
||||
|
||||
// Editor
|
||||
--editor-toolkit-bg: rgba(255, 255, 255, 0.5);
|
||||
|
@ -7,6 +7,7 @@
|
||||
--yellow-accent: #ffcc00;
|
||||
--link-color: var(--yellow-accent);
|
||||
--border-color: rgb(100, 100, 120);
|
||||
--border-color-light: rgba(102, 102, 102);
|
||||
|
||||
// Editor
|
||||
--editor-toolkit-bg: rgba(0, 0, 0, 0.5);
|
||||
|
@ -20,6 +20,7 @@
|
||||
@use '../components/shared/Selector';
|
||||
@use '../components/shared/Switch';
|
||||
@use '../components/shared/NumberInput';
|
||||
@use '../components/shared/Toast';
|
||||
|
||||
// Main CSS
|
||||
*,
|
||||
|
@ -1565,6 +1565,19 @@
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-dismissable-layer@0.1.5":
|
||||
version "0.1.5"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-0.1.5.tgz#9379032351e79028d472733a5cc8ba4a0ea43314"
|
||||
integrity sha512-J+fYWijkX4M4QKwf9dtu1oC0U6e6CEl8WhBp3Ad23yz2Hia0XCo6Pk/mp5CAFy4QBtQedTSkhW05AdtSOEoajQ==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/primitive" "0.1.0"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
"@radix-ui/react-use-body-pointer-events" "0.1.1"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
"@radix-ui/react-use-escape-keydown" "0.1.0"
|
||||
|
||||
"@radix-ui/react-id@0.1.5":
|
||||
version "0.1.5"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8"
|
||||
@ -1584,6 +1597,24 @@
|
||||
"@radix-ui/react-id" "0.1.5"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
|
||||
"@radix-ui/react-portal@0.1.4":
|
||||
version "0.1.4"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2"
|
||||
integrity sha512-MO0wRy2eYRTZ/CyOri9NANCAtAtq89DEtg90gicaTlkCfdqCLEBsLb+/q66BZQTr3xX/Vq01nnVfc/TkCqoqvw==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
|
||||
"@radix-ui/react-presence@0.1.2":
|
||||
version "0.1.2"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3"
|
||||
integrity sha512-3BRlFZraooIUfRlyN+b/Xs5hq1lanOOo/+3h6Pwu2GMFjkGKKa4Rd51fcqGqnVlbr3jYg+WLuGyAV4KlgqwrQw==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
|
||||
"@radix-ui/react-primitive@0.1.4":
|
||||
version "0.1.4"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1"
|
||||
@ -1615,6 +1646,32 @@
|
||||
"@radix-ui/react-use-previous" "0.1.1"
|
||||
"@radix-ui/react-use-size" "0.1.1"
|
||||
|
||||
"@radix-ui/react-toast@^0.1.1":
|
||||
version "0.1.1"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-toast/-/react-toast-0.1.1.tgz#d544e796b307e56f1298e40f356f468680958e93"
|
||||
integrity sha512-9JWC4mPP78OE6muDrpaPf/71dIeozppdcnik1IvsjTxZpDnt9PbTtQj94DdWjlCphbv3S5faD3KL0GOpqKBpTQ==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/primitive" "0.1.0"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-context" "0.1.1"
|
||||
"@radix-ui/react-dismissable-layer" "0.1.5"
|
||||
"@radix-ui/react-portal" "0.1.4"
|
||||
"@radix-ui/react-presence" "0.1.2"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
"@radix-ui/react-use-controllable-state" "0.1.0"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
"@radix-ui/react-visually-hidden" "0.1.4"
|
||||
|
||||
"@radix-ui/react-use-body-pointer-events@0.1.1":
|
||||
version "0.1.1"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-body-pointer-events/-/react-use-body-pointer-events-0.1.1.tgz#63e7fd81ca7ffd30841deb584cd2b7f460df2597"
|
||||
integrity sha512-R8leV2AWmJokTmERM8cMXFHWSiv/fzOLhG/JLmRBhLTAzOj37EQizssq4oW0Z29VcZy2tODMi9Pk/htxwb+xpA==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
|
||||
"@radix-ui/react-use-callback-ref@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-0.1.0.tgz#934b6e123330f5b3a6b116460e6662cbc663493f"
|
||||
@ -1630,6 +1687,14 @@
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
|
||||
"@radix-ui/react-use-escape-keydown@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874"
|
||||
integrity sha512-tDLZbTGFmvXaazUXXv8kYbiCcbAE8yKgng9s95d8fCO+Eundv0Jngbn/hKPhDDs4jj9ChwRX5cDDnlaN+ugYYQ==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
|
||||
"@radix-ui/react-use-layout-effect@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223"
|
||||
@ -1651,6 +1716,14 @@
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-visually-hidden@0.1.4":
|
||||
version "0.1.4"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03"
|
||||
integrity sha512-K/q6AEEzqeeEq/T0NPChvBqnwlp8Tl4NnQdrI/y8IOY7BRR+Ug0PEsVk6g48HJ7cA1//COugdxXXVVK/m0X1mA==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
|
||||
"@rollup/plugin-node-resolve@^7.1.1":
|
||||
version "7.1.3"
|
||||
resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz"
|
||||
|
@ -9,7 +9,7 @@ import torch
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
|
||||
|
||||
def download_model(url):
|
||||
def get_cache_path_by_url(url):
|
||||
parts = urlparse(url)
|
||||
hub_dir = get_dir()
|
||||
model_dir = os.path.join(hub_dir, "checkpoints")
|
||||
@ -17,6 +17,11 @@ def download_model(url):
|
||||
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(model_dir, filename)
|
||||
return cached_file
|
||||
|
||||
|
||||
def download_model(url):
|
||||
cached_file = get_cache_path_by_url(url)
|
||||
if not os.path.exists(cached_file):
|
||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = None
|
||||
|
@ -24,6 +24,11 @@ class InpaintModel:
|
||||
def init_model(self, device):
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def is_downloaded() -> bool:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
|
||||
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
@ -36,12 +36,16 @@ class LaMa(InpaintModel):
|
||||
)
|
||||
else:
|
||||
model_path = download_model(LAMA_MODEL_URL)
|
||||
|
||||
logger.info(f"Load LaMa model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.model_path = model_path
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
|
@ -10,7 +10,7 @@ from lama_cleaner.schema import Config
|
||||
torch.manual_seed(42)
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from lama_cleaner.helper import download_model, norm_img
|
||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
||||
timestep_embedding
|
||||
|
||||
@ -266,6 +266,7 @@ class DDIMSampler(object):
|
||||
|
||||
def load_jit_model(url, device):
|
||||
model_path = download_model(url)
|
||||
logger.info(f"Load LDM model from: {model_path}")
|
||||
model = torch.jit.load(model_path).to(device)
|
||||
model.eval()
|
||||
return model
|
||||
@ -286,6 +287,15 @@ class LDM(InpaintModel):
|
||||
model = LatentDiffusion(self.diffusion_model, device)
|
||||
self.sampler = DDIMSampler(model)
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
model_paths = [
|
||||
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
||||
]
|
||||
return all([os.path.exists(it) for it in model_paths])
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""
|
||||
image: [H, W, C] RGB
|
||||
|
@ -21,6 +21,14 @@ class ModelManager:
|
||||
raise NotImplementedError(f"Not supported model: {name}")
|
||||
return model
|
||||
|
||||
def is_downloaded(self, name: str) -> bool:
|
||||
if name == self.LAMA:
|
||||
return LaMa.is_downloaded()
|
||||
elif name == self.LDM:
|
||||
return LDM.is_downloaded()
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported model: {name}")
|
||||
|
||||
def __call__(self, image, mask, config: Config):
|
||||
return self.model(image, mask, config)
|
||||
|
||||
|
12
main.py
12
main.py
@ -125,7 +125,17 @@ def process():
|
||||
)
|
||||
|
||||
|
||||
@app.route("/switch_model", methods=["POST"])
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
|
||||
|
||||
@app.route("/model_downloaded/<name>")
|
||||
def model_downloaded(name):
|
||||
return str(model.is_downloaded(name)), 200
|
||||
|
||||
|
||||
@app.route("/model", methods=["POST"])
|
||||
def switch_model():
|
||||
new_name = request.form.get("name")
|
||||
if new_name == model.name:
|
||||
|
Loading…
Reference in New Issue
Block a user