make model switch work with toast
This commit is contained in:
parent
205286a414
commit
f7e1e073dc
@ -6,6 +6,7 @@
|
|||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@heroicons/react": "^1.0.4",
|
"@heroicons/react": "^1.0.4",
|
||||||
"@radix-ui/react-switch": "^0.1.5",
|
"@radix-ui/react-switch": "^0.1.5",
|
||||||
|
"@radix-ui/react-toast": "^0.1.1",
|
||||||
"@testing-library/jest-dom": "^5.14.1",
|
"@testing-library/jest-dom": "^5.14.1",
|
||||||
"@testing-library/react": "^12.1.2",
|
"@testing-library/react": "^12.1.2",
|
||||||
"@testing-library/user-event": "^13.5.0",
|
"@testing-library/user-event": "^13.5.0",
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Setting } from '../store/Atoms'
|
import { Settings } from '../store/Atoms'
|
||||||
import { dataURItoBlob } from '../utils'
|
import { dataURItoBlob } from '../utils'
|
||||||
|
|
||||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||||
@ -6,7 +6,7 @@ export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
|||||||
export default async function inpaint(
|
export default async function inpaint(
|
||||||
imageFile: File,
|
imageFile: File,
|
||||||
maskBase64: string,
|
maskBase64: string,
|
||||||
settings: Setting,
|
settings: Settings,
|
||||||
sizeLimit?: string
|
sizeLimit?: string
|
||||||
) {
|
) {
|
||||||
// 1080, 2000, Original
|
// 1080, 2000, Original
|
||||||
@ -43,8 +43,20 @@ export default async function inpaint(
|
|||||||
export function switchModel(name: string) {
|
export function switchModel(name: string) {
|
||||||
const fd = new FormData()
|
const fd = new FormData()
|
||||||
fd.append('name', name)
|
fd.append('name', name)
|
||||||
return fetch(`${API_ENDPOINT}/switch_model`, {
|
return fetch(`${API_ENDPOINT}/model`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: fd,
|
body: fd,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function currentModel() {
|
||||||
|
return fetch(`${API_ENDPOINT}/model`, {
|
||||||
|
method: 'GET',
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export function modelDownloaded(name: string) {
|
||||||
|
return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, {
|
||||||
|
method: 'GET',
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -1,26 +1,28 @@
|
|||||||
import React from 'react'
|
import React from 'react'
|
||||||
|
|
||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import { switchModel } from '../../adapters/inpainting'
|
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState } from '../../store/Atoms'
|
||||||
import Modal from '../shared/Modal'
|
import Modal from '../shared/Modal'
|
||||||
import HDSettingBlock from './HDSettingBlock'
|
import HDSettingBlock from './HDSettingBlock'
|
||||||
import ModelSettingBlock from './ModelSettingBlock'
|
import ModelSettingBlock from './ModelSettingBlock'
|
||||||
|
|
||||||
export default function SettingModal() {
|
interface SettingModalProps {
|
||||||
|
onClose: () => void
|
||||||
|
}
|
||||||
|
export default function SettingModal(props: SettingModalProps) {
|
||||||
|
const { onClose } = props
|
||||||
const [setting, setSettingState] = useRecoilState(settingState)
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
|
||||||
const onClose = () => {
|
const handleOnClose = () => {
|
||||||
setSettingState(old => {
|
setSettingState(old => {
|
||||||
return { ...old, show: false }
|
return { ...old, show: false }
|
||||||
})
|
})
|
||||||
|
onClose()
|
||||||
switchModel(setting.model)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Modal
|
<Modal
|
||||||
onClose={onClose}
|
onClose={handleOnClose}
|
||||||
title="Settings"
|
title="Settings"
|
||||||
className="modal-setting"
|
className="modal-setting"
|
||||||
show={setting.show}
|
show={setting.show}
|
||||||
|
@ -1,18 +1,99 @@
|
|||||||
import React from 'react'
|
import React, { useEffect } from 'react'
|
||||||
|
import { useRecoilState } from 'recoil'
|
||||||
import Editor from './Editor/Editor'
|
import Editor from './Editor/Editor'
|
||||||
import ShortcutsModal from './Shortcuts/ShortcutsModal'
|
import ShortcutsModal from './Shortcuts/ShortcutsModal'
|
||||||
import SettingModal from './Settings/SettingsModal'
|
import SettingModal from './Settings/SettingsModal'
|
||||||
|
import Toast from './shared/Toast'
|
||||||
|
import { Settings, settingState, toastState } from '../store/Atoms'
|
||||||
|
import {
|
||||||
|
currentModel,
|
||||||
|
modelDownloaded,
|
||||||
|
switchModel,
|
||||||
|
} from '../adapters/inpainting'
|
||||||
|
import { AIModel } from './Settings/ModelSettingBlock'
|
||||||
|
|
||||||
interface WorkspaceProps {
|
interface WorkspaceProps {
|
||||||
file: File
|
file: File
|
||||||
}
|
}
|
||||||
|
|
||||||
const Workspace = ({ file }: WorkspaceProps) => {
|
const Workspace = ({ file }: WorkspaceProps) => {
|
||||||
|
const [settings, setSettingState] = useRecoilState(settingState)
|
||||||
|
const [toastVal, setToastState] = useRecoilState(toastState)
|
||||||
|
|
||||||
|
const onSettingClose = async () => {
|
||||||
|
const curModel = await currentModel().then(res => res.text())
|
||||||
|
if (curModel === settings.model) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const downloaded = await modelDownloaded(settings.model).then(res =>
|
||||||
|
res.text()
|
||||||
|
)
|
||||||
|
|
||||||
|
const { model } = settings
|
||||||
|
|
||||||
|
let loadingMessage = `Switching to ${model} model`
|
||||||
|
let loadingDuration = 3000
|
||||||
|
if (downloaded === 'False') {
|
||||||
|
loadingMessage = `Downloading ${model} model, this may take a while`
|
||||||
|
loadingDuration = 9999999999
|
||||||
|
}
|
||||||
|
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: loadingMessage,
|
||||||
|
state: 'loading',
|
||||||
|
duration: loadingDuration,
|
||||||
|
})
|
||||||
|
|
||||||
|
switchModel(model)
|
||||||
|
.then(res => {
|
||||||
|
if (res.ok) {
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: `Switch to ${model} model success`,
|
||||||
|
state: 'success',
|
||||||
|
duration: 3000,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
throw new Error('Server error')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: `Switch to ${model} model failed`,
|
||||||
|
state: 'error',
|
||||||
|
duration: 3000,
|
||||||
|
})
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, model: curModel as AIModel }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
currentModel()
|
||||||
|
.then(res => res.text())
|
||||||
|
.then(model => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, model: model as AIModel }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}, [])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Editor file={file} />
|
<Editor file={file} />
|
||||||
<SettingModal />
|
<SettingModal onClose={onSettingClose} />
|
||||||
<ShortcutsModal />
|
<ShortcutsModal />
|
||||||
|
<Toast
|
||||||
|
{...toastVal}
|
||||||
|
onOpenChange={(open: boolean) => {
|
||||||
|
setToastState(old => {
|
||||||
|
return { ...old, open }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
</>
|
</>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
83
lama_cleaner/app/src/components/shared/Toast.scss
Normal file
83
lama_cleaner/app/src/components/shared/Toast.scss
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
.toast-viewpoint {
|
||||||
|
position: fixed;
|
||||||
|
top: 48px;
|
||||||
|
right: 0;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
padding: 25px;
|
||||||
|
gap: 10px;
|
||||||
|
max-width: 100vw;
|
||||||
|
margin: 0;
|
||||||
|
z-index: 999999;
|
||||||
|
|
||||||
|
&:focus-visible {
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.toast-root {
|
||||||
|
border: 1px solid var(--border-color-light);
|
||||||
|
background-color: var(--page-bg);
|
||||||
|
border-radius: 0.6rem;
|
||||||
|
padding: 15px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
|
||||||
|
gap: 12px;
|
||||||
|
|
||||||
|
&[data-state='open'] {
|
||||||
|
animation: slideIn 150ms cubic-bezier(0.16, 1, 0.3, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
&[data-state='close'] {
|
||||||
|
animation: opacityReveal 100ms ease-in forwards;
|
||||||
|
}
|
||||||
|
|
||||||
|
&[data-state='cancel'] {
|
||||||
|
transform: translateX(0);
|
||||||
|
animation: transform 100ms ease-out;
|
||||||
|
}
|
||||||
|
|
||||||
|
&.error {
|
||||||
|
border: 1px solid var(--error-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
&.success {
|
||||||
|
border: 1px solid var(--success-color);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.error-icon {
|
||||||
|
height: 24px;
|
||||||
|
width: 24px;
|
||||||
|
color: var(--error-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.success-icon {
|
||||||
|
height: 24px;
|
||||||
|
width: 24px;
|
||||||
|
color: var(--success-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.loading-icon {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
animation-name: spin;
|
||||||
|
animation-duration: 1500ms;
|
||||||
|
animation-iteration-count: infinite;
|
||||||
|
transform-origin: center center;
|
||||||
|
animation-timing-function: linear;
|
||||||
|
}
|
||||||
|
|
||||||
|
.toast-icon {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.toast-desc {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
margin: 0;
|
||||||
|
color: var(--text-color);
|
||||||
|
min-width: 240px;
|
||||||
|
}
|
81
lama_cleaner/app/src/components/shared/Toast.tsx
Normal file
81
lama_cleaner/app/src/components/shared/Toast.tsx
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import * as React from 'react'
|
||||||
|
import * as ToastPrimitive from '@radix-ui/react-toast'
|
||||||
|
import { ToastProps } from '@radix-ui/react-toast'
|
||||||
|
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/outline'
|
||||||
|
|
||||||
|
const LoadingIcon = () => {
|
||||||
|
return (
|
||||||
|
<span className="loading-icon">
|
||||||
|
<svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
width="20"
|
||||||
|
height="20"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
fill="none"
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeWidth="2"
|
||||||
|
strokeLinecap="round"
|
||||||
|
strokeLinejoin="round"
|
||||||
|
>
|
||||||
|
<line x1="12" y1="2" x2="12" y2="6" />
|
||||||
|
<line x1="12" y1="18" x2="12" y2="22" />
|
||||||
|
<line x1="4.93" y1="4.93" x2="7.76" y2="7.76" />
|
||||||
|
<line x1="16.24" y1="16.24" x2="19.07" y2="19.07" />
|
||||||
|
<line x1="2" y1="12" x2="6" y2="12" />
|
||||||
|
<line x1="18" y1="12" x2="22" y2="12" />
|
||||||
|
<line x1="4.93" y1="19.07" x2="7.76" y2="16.24" />
|
||||||
|
<line x1="16.24" y1="7.76" x2="19.07" y2="4.93" />
|
||||||
|
</svg>
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ToastState = 'default' | 'error' | 'loading' | 'success'
|
||||||
|
|
||||||
|
interface MyToastProps extends ToastProps {
|
||||||
|
desc: string
|
||||||
|
state?: ToastState
|
||||||
|
}
|
||||||
|
|
||||||
|
const Toast = React.forwardRef<
|
||||||
|
React.ElementRef<typeof ToastPrimitive.Root>,
|
||||||
|
MyToastProps
|
||||||
|
>((props, forwardedRef) => {
|
||||||
|
const { state, desc, ...itemProps } = props
|
||||||
|
|
||||||
|
const getIcon = () => {
|
||||||
|
switch (state) {
|
||||||
|
case 'error':
|
||||||
|
return <ExclamationCircleIcon className="error-icon" />
|
||||||
|
case 'success':
|
||||||
|
return <CheckIcon className="success-icon" />
|
||||||
|
case 'loading':
|
||||||
|
return <LoadingIcon />
|
||||||
|
default:
|
||||||
|
return <></>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ToastPrimitive.Provider>
|
||||||
|
<ToastPrimitive.Root
|
||||||
|
{...itemProps}
|
||||||
|
ref={forwardedRef}
|
||||||
|
className={`toast-root ${state}`}
|
||||||
|
>
|
||||||
|
<div className="toast-icon">{getIcon()}</div>
|
||||||
|
<ToastPrimitive.Description className="toast-desc">
|
||||||
|
{desc}
|
||||||
|
</ToastPrimitive.Description>
|
||||||
|
</ToastPrimitive.Root>
|
||||||
|
<ToastPrimitive.Viewport className="toast-viewpoint" />
|
||||||
|
</ToastPrimitive.Provider>
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Toast.defaultProps = {
|
||||||
|
desc: '',
|
||||||
|
state: 'loading',
|
||||||
|
}
|
||||||
|
|
||||||
|
export default Toast
|
@ -1,18 +1,36 @@
|
|||||||
import { atom } from 'recoil'
|
import { atom } from 'recoil'
|
||||||
import { HDStrategy } from '../components/Settings/HDSettingBlock'
|
import { HDStrategy } from '../components/Settings/HDSettingBlock'
|
||||||
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
||||||
|
import { ToastState } from '../components/shared/Toast'
|
||||||
|
|
||||||
export const fileState = atom<File | undefined>({
|
export const fileState = atom<File | undefined>({
|
||||||
key: 'fileState',
|
key: 'fileState',
|
||||||
default: undefined,
|
default: undefined,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
interface ToastAtomState {
|
||||||
|
open: boolean
|
||||||
|
desc: string
|
||||||
|
state: ToastState
|
||||||
|
duration: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export const toastState = atom<ToastAtomState>({
|
||||||
|
key: 'toastState',
|
||||||
|
default: {
|
||||||
|
open: false,
|
||||||
|
desc: '',
|
||||||
|
state: 'default',
|
||||||
|
duration: 3000,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
export const shortcutsState = atom<boolean>({
|
export const shortcutsState = atom<boolean>({
|
||||||
key: 'shortcutsState',
|
key: 'shortcutsState',
|
||||||
default: false,
|
default: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
export interface Setting {
|
export interface Settings {
|
||||||
show: boolean
|
show: boolean
|
||||||
saveImageBesideOrigin: boolean
|
saveImageBesideOrigin: boolean
|
||||||
model: AIModel
|
model: AIModel
|
||||||
@ -27,16 +45,18 @@ export interface Setting {
|
|||||||
ldmSteps: number
|
ldmSteps: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export const settingState = atom<Setting>({
|
export const settingStateDefault = {
|
||||||
|
show: false,
|
||||||
|
saveImageBesideOrigin: false,
|
||||||
|
model: AIModel.LAMA,
|
||||||
|
ldmSteps: 50,
|
||||||
|
hdStrategy: HDStrategy.RESIZE,
|
||||||
|
hdStrategyResizeLimit: 2048,
|
||||||
|
hdStrategyCropTrigerSize: 2048,
|
||||||
|
hdStrategyCropMargin: 128,
|
||||||
|
}
|
||||||
|
|
||||||
|
export const settingState = atom<Settings>({
|
||||||
key: 'settingsState',
|
key: 'settingsState',
|
||||||
default: {
|
default: settingStateDefault,
|
||||||
show: false,
|
|
||||||
saveImageBesideOrigin: false,
|
|
||||||
model: AIModel.LAMA,
|
|
||||||
hdStrategy: HDStrategy.RESIZE,
|
|
||||||
hdStrategyResizeLimit: 2048,
|
|
||||||
hdStrategyCropTrigerSize: 2048,
|
|
||||||
hdStrategyCropMargin: 128,
|
|
||||||
ldmSteps: 50,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
|
@ -37,3 +37,21 @@
|
|||||||
transform: translateY(0);
|
transform: translateY(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@keyframes slideIn {
|
||||||
|
0% {
|
||||||
|
transform: translateX(calc(100% + 25px));
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: translateX(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes spin {
|
||||||
|
0% {
|
||||||
|
transform: rotate(0deg);
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: rotate(360deg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -7,6 +7,10 @@
|
|||||||
--yellow-accent: #ffcc00;
|
--yellow-accent: #ffcc00;
|
||||||
--link-color: rgb(0, 0, 0);
|
--link-color: rgb(0, 0, 0);
|
||||||
--border-color: rgb(100, 100, 120);
|
--border-color: rgb(100, 100, 120);
|
||||||
|
--border-color-light: rgba(100, 100, 120, 0.5);
|
||||||
|
|
||||||
|
--error-color: rgb(239, 68, 68);
|
||||||
|
--success-color: rgb(16, 185, 129);
|
||||||
|
|
||||||
// Editor
|
// Editor
|
||||||
--editor-toolkit-bg: rgba(255, 255, 255, 0.5);
|
--editor-toolkit-bg: rgba(255, 255, 255, 0.5);
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
--yellow-accent: #ffcc00;
|
--yellow-accent: #ffcc00;
|
||||||
--link-color: var(--yellow-accent);
|
--link-color: var(--yellow-accent);
|
||||||
--border-color: rgb(100, 100, 120);
|
--border-color: rgb(100, 100, 120);
|
||||||
|
--border-color-light: rgba(102, 102, 102);
|
||||||
|
|
||||||
// Editor
|
// Editor
|
||||||
--editor-toolkit-bg: rgba(0, 0, 0, 0.5);
|
--editor-toolkit-bg: rgba(0, 0, 0, 0.5);
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
@use '../components/shared/Selector';
|
@use '../components/shared/Selector';
|
||||||
@use '../components/shared/Switch';
|
@use '../components/shared/Switch';
|
||||||
@use '../components/shared/NumberInput';
|
@use '../components/shared/NumberInput';
|
||||||
|
@use '../components/shared/Toast';
|
||||||
|
|
||||||
// Main CSS
|
// Main CSS
|
||||||
*,
|
*,
|
||||||
|
@ -1565,6 +1565,19 @@
|
|||||||
dependencies:
|
dependencies:
|
||||||
"@babel/runtime" "^7.13.10"
|
"@babel/runtime" "^7.13.10"
|
||||||
|
|
||||||
|
"@radix-ui/react-dismissable-layer@0.1.5":
|
||||||
|
version "0.1.5"
|
||||||
|
resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-0.1.5.tgz#9379032351e79028d472733a5cc8ba4a0ea43314"
|
||||||
|
integrity sha512-J+fYWijkX4M4QKwf9dtu1oC0U6e6CEl8WhBp3Ad23yz2Hia0XCo6Pk/mp5CAFy4QBtQedTSkhW05AdtSOEoajQ==
|
||||||
|
dependencies:
|
||||||
|
"@babel/runtime" "^7.13.10"
|
||||||
|
"@radix-ui/primitive" "0.1.0"
|
||||||
|
"@radix-ui/react-compose-refs" "0.1.0"
|
||||||
|
"@radix-ui/react-primitive" "0.1.4"
|
||||||
|
"@radix-ui/react-use-body-pointer-events" "0.1.1"
|
||||||
|
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||||
|
"@radix-ui/react-use-escape-keydown" "0.1.0"
|
||||||
|
|
||||||
"@radix-ui/react-id@0.1.5":
|
"@radix-ui/react-id@0.1.5":
|
||||||
version "0.1.5"
|
version "0.1.5"
|
||||||
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8"
|
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8"
|
||||||
@ -1584,6 +1597,24 @@
|
|||||||
"@radix-ui/react-id" "0.1.5"
|
"@radix-ui/react-id" "0.1.5"
|
||||||
"@radix-ui/react-primitive" "0.1.4"
|
"@radix-ui/react-primitive" "0.1.4"
|
||||||
|
|
||||||
|
"@radix-ui/react-portal@0.1.4":
|
||||||
|
version "0.1.4"
|
||||||
|
resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2"
|
||||||
|
integrity sha512-MO0wRy2eYRTZ/CyOri9NANCAtAtq89DEtg90gicaTlkCfdqCLEBsLb+/q66BZQTr3xX/Vq01nnVfc/TkCqoqvw==
|
||||||
|
dependencies:
|
||||||
|
"@babel/runtime" "^7.13.10"
|
||||||
|
"@radix-ui/react-primitive" "0.1.4"
|
||||||
|
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||||
|
|
||||||
|
"@radix-ui/react-presence@0.1.2":
|
||||||
|
version "0.1.2"
|
||||||
|
resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3"
|
||||||
|
integrity sha512-3BRlFZraooIUfRlyN+b/Xs5hq1lanOOo/+3h6Pwu2GMFjkGKKa4Rd51fcqGqnVlbr3jYg+WLuGyAV4KlgqwrQw==
|
||||||
|
dependencies:
|
||||||
|
"@babel/runtime" "^7.13.10"
|
||||||
|
"@radix-ui/react-compose-refs" "0.1.0"
|
||||||
|
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||||
|
|
||||||
"@radix-ui/react-primitive@0.1.4":
|
"@radix-ui/react-primitive@0.1.4":
|
||||||
version "0.1.4"
|
version "0.1.4"
|
||||||
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1"
|
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1"
|
||||||
@ -1615,6 +1646,32 @@
|
|||||||
"@radix-ui/react-use-previous" "0.1.1"
|
"@radix-ui/react-use-previous" "0.1.1"
|
||||||
"@radix-ui/react-use-size" "0.1.1"
|
"@radix-ui/react-use-size" "0.1.1"
|
||||||
|
|
||||||
|
"@radix-ui/react-toast@^0.1.1":
|
||||||
|
version "0.1.1"
|
||||||
|
resolved "https://registry.npmmirror.com/@radix-ui/react-toast/-/react-toast-0.1.1.tgz#d544e796b307e56f1298e40f356f468680958e93"
|
||||||
|
integrity sha512-9JWC4mPP78OE6muDrpaPf/71dIeozppdcnik1IvsjTxZpDnt9PbTtQj94DdWjlCphbv3S5faD3KL0GOpqKBpTQ==
|
||||||
|
dependencies:
|
||||||
|
"@babel/runtime" "^7.13.10"
|
||||||
|
"@radix-ui/primitive" "0.1.0"
|
||||||
|
"@radix-ui/react-compose-refs" "0.1.0"
|
||||||
|
"@radix-ui/react-context" "0.1.1"
|
||||||
|
"@radix-ui/react-dismissable-layer" "0.1.5"
|
||||||
|
"@radix-ui/react-portal" "0.1.4"
|
||||||
|
"@radix-ui/react-presence" "0.1.2"
|
||||||
|
"@radix-ui/react-primitive" "0.1.4"
|
||||||
|
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||||
|
"@radix-ui/react-use-controllable-state" "0.1.0"
|
||||||
|
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||||
|
"@radix-ui/react-visually-hidden" "0.1.4"
|
||||||
|
|
||||||
|
"@radix-ui/react-use-body-pointer-events@0.1.1":
|
||||||
|
version "0.1.1"
|
||||||
|
resolved "https://registry.npmmirror.com/@radix-ui/react-use-body-pointer-events/-/react-use-body-pointer-events-0.1.1.tgz#63e7fd81ca7ffd30841deb584cd2b7f460df2597"
|
||||||
|
integrity sha512-R8leV2AWmJokTmERM8cMXFHWSiv/fzOLhG/JLmRBhLTAzOj37EQizssq4oW0Z29VcZy2tODMi9Pk/htxwb+xpA==
|
||||||
|
dependencies:
|
||||||
|
"@babel/runtime" "^7.13.10"
|
||||||
|
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||||
|
|
||||||
"@radix-ui/react-use-callback-ref@0.1.0":
|
"@radix-ui/react-use-callback-ref@0.1.0":
|
||||||
version "0.1.0"
|
version "0.1.0"
|
||||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-0.1.0.tgz#934b6e123330f5b3a6b116460e6662cbc663493f"
|
resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-0.1.0.tgz#934b6e123330f5b3a6b116460e6662cbc663493f"
|
||||||
@ -1630,6 +1687,14 @@
|
|||||||
"@babel/runtime" "^7.13.10"
|
"@babel/runtime" "^7.13.10"
|
||||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||||
|
|
||||||
|
"@radix-ui/react-use-escape-keydown@0.1.0":
|
||||||
|
version "0.1.0"
|
||||||
|
resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874"
|
||||||
|
integrity sha512-tDLZbTGFmvXaazUXXv8kYbiCcbAE8yKgng9s95d8fCO+Eundv0Jngbn/hKPhDDs4jj9ChwRX5cDDnlaN+ugYYQ==
|
||||||
|
dependencies:
|
||||||
|
"@babel/runtime" "^7.13.10"
|
||||||
|
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||||
|
|
||||||
"@radix-ui/react-use-layout-effect@0.1.0":
|
"@radix-ui/react-use-layout-effect@0.1.0":
|
||||||
version "0.1.0"
|
version "0.1.0"
|
||||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223"
|
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223"
|
||||||
@ -1651,6 +1716,14 @@
|
|||||||
dependencies:
|
dependencies:
|
||||||
"@babel/runtime" "^7.13.10"
|
"@babel/runtime" "^7.13.10"
|
||||||
|
|
||||||
|
"@radix-ui/react-visually-hidden@0.1.4":
|
||||||
|
version "0.1.4"
|
||||||
|
resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03"
|
||||||
|
integrity sha512-K/q6AEEzqeeEq/T0NPChvBqnwlp8Tl4NnQdrI/y8IOY7BRR+Ug0PEsVk6g48HJ7cA1//COugdxXXVVK/m0X1mA==
|
||||||
|
dependencies:
|
||||||
|
"@babel/runtime" "^7.13.10"
|
||||||
|
"@radix-ui/react-primitive" "0.1.4"
|
||||||
|
|
||||||
"@rollup/plugin-node-resolve@^7.1.1":
|
"@rollup/plugin-node-resolve@^7.1.1":
|
||||||
version "7.1.3"
|
version "7.1.3"
|
||||||
resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz"
|
resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz"
|
||||||
|
@ -9,7 +9,7 @@ import torch
|
|||||||
from torch.hub import download_url_to_file, get_dir
|
from torch.hub import download_url_to_file, get_dir
|
||||||
|
|
||||||
|
|
||||||
def download_model(url):
|
def get_cache_path_by_url(url):
|
||||||
parts = urlparse(url)
|
parts = urlparse(url)
|
||||||
hub_dir = get_dir()
|
hub_dir = get_dir()
|
||||||
model_dir = os.path.join(hub_dir, "checkpoints")
|
model_dir = os.path.join(hub_dir, "checkpoints")
|
||||||
@ -17,6 +17,11 @@ def download_model(url):
|
|||||||
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
|
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
|
||||||
filename = os.path.basename(parts.path)
|
filename = os.path.basename(parts.path)
|
||||||
cached_file = os.path.join(model_dir, filename)
|
cached_file = os.path.join(model_dir, filename)
|
||||||
|
return cached_file
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(url):
|
||||||
|
cached_file = get_cache_path_by_url(url)
|
||||||
if not os.path.exists(cached_file):
|
if not os.path.exists(cached_file):
|
||||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||||
hash_prefix = None
|
hash_prefix = None
|
||||||
|
@ -24,6 +24,11 @@ class InpaintModel:
|
|||||||
def init_model(self, device):
|
def init_model(self, device):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_downloaded() -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
|
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
@ -36,12 +36,16 @@ class LaMa(InpaintModel):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = download_model(LAMA_MODEL_URL)
|
model_path = download_model(LAMA_MODEL_URL)
|
||||||
|
|
||||||
logger.info(f"Load LaMa model from: {model_path}")
|
logger.info(f"Load LaMa model from: {model_path}")
|
||||||
model = torch.jit.load(model_path, map_location="cpu")
|
model = torch.jit.load(model_path, map_location="cpu")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.model_path = model_path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_downloaded() -> bool:
|
||||||
|
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
|
@ -10,7 +10,7 @@ from lama_cleaner.schema import Config
|
|||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from lama_cleaner.helper import download_model, norm_img
|
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||||
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
||||||
timestep_embedding
|
timestep_embedding
|
||||||
|
|
||||||
@ -266,6 +266,7 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
def load_jit_model(url, device):
|
def load_jit_model(url, device):
|
||||||
model_path = download_model(url)
|
model_path = download_model(url)
|
||||||
|
logger.info(f"Load LDM model from: {model_path}")
|
||||||
model = torch.jit.load(model_path).to(device)
|
model = torch.jit.load(model_path).to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
@ -286,6 +287,15 @@ class LDM(InpaintModel):
|
|||||||
model = LatentDiffusion(self.diffusion_model, device)
|
model = LatentDiffusion(self.diffusion_model, device)
|
||||||
self.sampler = DDIMSampler(model)
|
self.sampler = DDIMSampler(model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_downloaded() -> bool:
|
||||||
|
model_paths = [
|
||||||
|
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
||||||
|
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
||||||
|
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
||||||
|
]
|
||||||
|
return all([os.path.exists(it) for it in model_paths])
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""
|
"""
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
|
@ -21,6 +21,14 @@ class ModelManager:
|
|||||||
raise NotImplementedError(f"Not supported model: {name}")
|
raise NotImplementedError(f"Not supported model: {name}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def is_downloaded(self, name: str) -> bool:
|
||||||
|
if name == self.LAMA:
|
||||||
|
return LaMa.is_downloaded()
|
||||||
|
elif name == self.LDM:
|
||||||
|
return LDM.is_downloaded()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Not supported model: {name}")
|
||||||
|
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: Config):
|
||||||
return self.model(image, mask, config)
|
return self.model(image, mask, config)
|
||||||
|
|
||||||
|
12
main.py
12
main.py
@ -125,7 +125,17 @@ def process():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/switch_model", methods=["POST"])
|
@app.route("/model")
|
||||||
|
def current_model():
|
||||||
|
return model.name, 200
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/model_downloaded/<name>")
|
||||||
|
def model_downloaded(name):
|
||||||
|
return str(model.is_downloaded(name)), 200
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/model", methods=["POST"])
|
||||||
def switch_model():
|
def switch_model():
|
||||||
new_name = request.form.get("name")
|
new_name = request.form.get("name")
|
||||||
if new_name == model.name:
|
if new_name == model.name:
|
||||||
|
Loading…
Reference in New Issue
Block a user