big update
@ -1,10 +1,12 @@
|
|||||||
|
import { Setting } from '../store/Atoms'
|
||||||
import { dataURItoBlob } from '../utils'
|
import { dataURItoBlob } from '../utils'
|
||||||
|
|
||||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
|
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||||
|
|
||||||
export default async function inpaint(
|
export default async function inpaint(
|
||||||
imageFile: File,
|
imageFile: File,
|
||||||
maskBase64: string,
|
maskBase64: string,
|
||||||
|
settings: Setting,
|
||||||
sizeLimit?: string
|
sizeLimit?: string
|
||||||
) {
|
) {
|
||||||
// 1080, 2000, Original
|
// 1080, 2000, Original
|
||||||
@ -13,13 +15,22 @@ export default async function inpaint(
|
|||||||
const mask = dataURItoBlob(maskBase64)
|
const mask = dataURItoBlob(maskBase64)
|
||||||
fd.append('mask', mask)
|
fd.append('mask', mask)
|
||||||
|
|
||||||
|
fd.append('ldmSteps', settings.ldmSteps.toString())
|
||||||
|
fd.append('hdStrategy', settings.hdStrategy)
|
||||||
|
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
|
||||||
|
fd.append(
|
||||||
|
'hdStrategyCropTrigerSize',
|
||||||
|
settings.hdStrategyCropTrigerSize.toString()
|
||||||
|
)
|
||||||
|
fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())
|
||||||
|
|
||||||
if (sizeLimit === undefined) {
|
if (sizeLimit === undefined) {
|
||||||
fd.append('sizeLimit', '1080')
|
fd.append('sizeLimit', '1080')
|
||||||
} else {
|
} else {
|
||||||
fd.append('sizeLimit', sizeLimit)
|
fd.append('sizeLimit', sizeLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
const res = await fetch(API_ENDPOINT, {
|
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: fd,
|
body: fd,
|
||||||
}).then(async r => {
|
}).then(async r => {
|
||||||
@ -28,3 +39,12 @@ export default async function inpaint(
|
|||||||
|
|
||||||
return URL.createObjectURL(res)
|
return URL.createObjectURL(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function switchModel(name: string) {
|
||||||
|
const fd = new FormData()
|
||||||
|
fd.append('name', name)
|
||||||
|
return fetch(`${API_ENDPOINT}/switch_model`, {
|
||||||
|
method: 'POST',
|
||||||
|
body: fd,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -15,12 +15,14 @@ import {
|
|||||||
TransformComponent,
|
TransformComponent,
|
||||||
TransformWrapper,
|
TransformWrapper,
|
||||||
} from 'react-zoom-pan-pinch'
|
} from 'react-zoom-pan-pinch'
|
||||||
|
import { useRecoilValue } from 'recoil'
|
||||||
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
||||||
import inpaint from '../../adapters/inpainting'
|
import inpaint from '../../adapters/inpainting'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
import Slider from './Slider'
|
import Slider from './Slider'
|
||||||
import SizeSelector from './SizeSelector'
|
import SizeSelector from './SizeSelector'
|
||||||
import { downloadImage, loadImage, useImage } from '../../utils'
|
import { downloadImage, loadImage, useImage } from '../../utils'
|
||||||
|
import { settingState } from '../../store/Atoms'
|
||||||
|
|
||||||
const TOOLBAR_SIZE = 200
|
const TOOLBAR_SIZE = 200
|
||||||
const BRUSH_COLOR = '#ffcc00bb'
|
const BRUSH_COLOR = '#ffcc00bb'
|
||||||
@ -57,6 +59,7 @@ function drawLines(
|
|||||||
|
|
||||||
export default function Editor(props: EditorProps) {
|
export default function Editor(props: EditorProps) {
|
||||||
const { file } = props
|
const { file } = props
|
||||||
|
const settings = useRecoilValue(settingState)
|
||||||
const [brushSize, setBrushSize] = useState(40)
|
const [brushSize, setBrushSize] = useState(40)
|
||||||
const [original, isOriginalLoaded] = useImage(file)
|
const [original, isOriginalLoaded] = useImage(file)
|
||||||
const [renders, setRenders] = useState<HTMLImageElement[]>([])
|
const [renders, setRenders] = useState<HTMLImageElement[]>([])
|
||||||
@ -125,6 +128,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
const res = await inpaint(
|
const res = await inpaint(
|
||||||
file,
|
file,
|
||||||
maskCanvas.toDataURL(),
|
maskCanvas.toDataURL(),
|
||||||
|
settings,
|
||||||
sizeLimit.toString()
|
sizeLimit.toString()
|
||||||
)
|
)
|
||||||
if (!res) {
|
if (!res) {
|
||||||
@ -157,6 +161,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
renders,
|
renders,
|
||||||
sizeLimit,
|
sizeLimit,
|
||||||
historyLineCount,
|
historyLineCount,
|
||||||
|
settings,
|
||||||
])
|
])
|
||||||
|
|
||||||
const hadDrawSomething = () => {
|
const hadDrawSomething = () => {
|
||||||
|
@ -6,7 +6,7 @@ import Button from '../shared/Button'
|
|||||||
import Shortcuts from '../Shortcuts/Shortcuts'
|
import Shortcuts from '../Shortcuts/Shortcuts'
|
||||||
import useResolution from '../../hooks/useResolution'
|
import useResolution from '../../hooks/useResolution'
|
||||||
import { ThemeChanger } from './ThemeChanger'
|
import { ThemeChanger } from './ThemeChanger'
|
||||||
import SettingIcon from '../Setting/SettingIcon'
|
import SettingIcon from '../Settings/SettingIcon'
|
||||||
|
|
||||||
const Header = () => {
|
const Header = () => {
|
||||||
const [file, setFile] = useRecoilState(fileState)
|
const [file, setFile] = useRecoilState(fileState)
|
||||||
|
@ -1,50 +1,16 @@
|
|||||||
import React, { ReactNode } from 'react'
|
import React, { ReactNode } from 'react'
|
||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState } from '../../store/Atoms'
|
||||||
import NumberInput from '../shared/NumberInput'
|
|
||||||
import Selector from '../shared/Selector'
|
import Selector from '../shared/Selector'
|
||||||
|
import NumberInputSetting from './NumberInputSetting'
|
||||||
import SettingBlock from './SettingBlock'
|
import SettingBlock from './SettingBlock'
|
||||||
|
|
||||||
export enum HDStrategy {
|
export enum HDStrategy {
|
||||||
ORIGINAL = 'Original',
|
ORIGINAL = 'Original',
|
||||||
REISIZE = 'Resize',
|
RESIZE = 'Resize',
|
||||||
CROP = 'Crop',
|
CROP = 'Crop',
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PixelSizeInputProps {
|
|
||||||
title: string
|
|
||||||
value: string
|
|
||||||
onValue: (val: string) => void
|
|
||||||
}
|
|
||||||
|
|
||||||
function PixelSizeInputSetting(props: PixelSizeInputProps) {
|
|
||||||
const { title, value, onValue } = props
|
|
||||||
|
|
||||||
return (
|
|
||||||
<SettingBlock
|
|
||||||
className="sub-setting-block"
|
|
||||||
title={title}
|
|
||||||
input={
|
|
||||||
<div
|
|
||||||
style={{
|
|
||||||
display: 'flex',
|
|
||||||
justifyContent: 'center',
|
|
||||||
alignItems: 'center',
|
|
||||||
gap: '8px',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<NumberInput
|
|
||||||
style={{ width: '80px' }}
|
|
||||||
value={`${value}`}
|
|
||||||
onValue={onValue}
|
|
||||||
/>
|
|
||||||
<span>pixel</span>
|
|
||||||
</div>
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
function HDSettingBlock() {
|
function HDSettingBlock() {
|
||||||
const [setting, setSettingState] = useRecoilState(settingState)
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
|
||||||
@ -84,7 +50,7 @@ function HDSettingBlock() {
|
|||||||
tabIndex={0}
|
tabIndex={0}
|
||||||
role="button"
|
role="button"
|
||||||
className="inline-tip"
|
className="inline-tip"
|
||||||
onClick={() => onStrategyChange(HDStrategy.REISIZE)}
|
onClick={() => onStrategyChange(HDStrategy.RESIZE)}
|
||||||
>
|
>
|
||||||
Resize Strategy
|
Resize Strategy
|
||||||
</div>{' '}
|
</div>{' '}
|
||||||
@ -100,9 +66,10 @@ function HDSettingBlock() {
|
|||||||
Resize the longer side of the image to a specific size(keep ratio),
|
Resize the longer side of the image to a specific size(keep ratio),
|
||||||
then do inpainting on the resized image.
|
then do inpainting on the resized image.
|
||||||
</div>
|
</div>
|
||||||
<PixelSizeInputSetting
|
<NumberInputSetting
|
||||||
title="Size limit"
|
title="Size limit"
|
||||||
value={`${setting.hdStrategyResizeLimit}`}
|
value={`${setting.hdStrategyResizeLimit}`}
|
||||||
|
suffix="pixel"
|
||||||
onValue={onResizeLimitChange}
|
onValue={onResizeLimitChange}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@ -117,14 +84,16 @@ function HDSettingBlock() {
|
|||||||
the result back. Mainly for performance and memory reasons on high
|
the result back. Mainly for performance and memory reasons on high
|
||||||
resolution image.
|
resolution image.
|
||||||
</div>
|
</div>
|
||||||
<PixelSizeInputSetting
|
<NumberInputSetting
|
||||||
title="Trigger size"
|
title="Trigger size"
|
||||||
value={`${setting.hdStrategyCropTrigerSize}`}
|
value={`${setting.hdStrategyCropTrigerSize}`}
|
||||||
|
suffix="pixel"
|
||||||
onValue={onCropTriggerSizeChange}
|
onValue={onCropTriggerSizeChange}
|
||||||
/>
|
/>
|
||||||
<PixelSizeInputSetting
|
<NumberInputSetting
|
||||||
title="Crop margin"
|
title="Crop margin"
|
||||||
value={`${setting.hdStrategyCropMargin}`}
|
value={`${setting.hdStrategyCropMargin}`}
|
||||||
|
suffix="pixel"
|
||||||
onValue={onCropMarginChange}
|
onValue={onCropMarginChange}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@ -137,7 +106,7 @@ function HDSettingBlock() {
|
|||||||
return renderOriginalOptionDesc()
|
return renderOriginalOptionDesc()
|
||||||
case HDStrategy.CROP:
|
case HDStrategy.CROP:
|
||||||
return renderCropOptionDesc()
|
return renderCropOptionDesc()
|
||||||
case HDStrategy.REISIZE:
|
case HDStrategy.RESIZE:
|
||||||
return renderResizeOptionDesc()
|
return renderResizeOptionDesc()
|
||||||
default:
|
default:
|
||||||
return renderOriginalOptionDesc()
|
return renderOriginalOptionDesc()
|
@ -2,11 +2,12 @@ import React, { ReactNode } from 'react'
|
|||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState } from '../../store/Atoms'
|
||||||
import Selector from '../shared/Selector'
|
import Selector from '../shared/Selector'
|
||||||
|
import NumberInputSetting from './NumberInputSetting'
|
||||||
import SettingBlock from './SettingBlock'
|
import SettingBlock from './SettingBlock'
|
||||||
|
|
||||||
export enum AIModel {
|
export enum AIModel {
|
||||||
LAMA = 'LaMa',
|
LAMA = 'lama',
|
||||||
LDM = 'LDM',
|
LDM = 'ldm',
|
||||||
}
|
}
|
||||||
|
|
||||||
function ModelSettingBlock() {
|
function ModelSettingBlock() {
|
||||||
@ -24,7 +25,7 @@ function ModelSettingBlock() {
|
|||||||
githubUrl: string
|
githubUrl: string
|
||||||
) => {
|
) => {
|
||||||
return (
|
return (
|
||||||
<div style={{ display: 'flex', flexDirection: 'column' }}>
|
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
|
||||||
<a
|
<a
|
||||||
className="model-desc-link"
|
className="model-desc-link"
|
||||||
href={paperUrl}
|
href={paperUrl}
|
||||||
@ -34,14 +35,11 @@ function ModelSettingBlock() {
|
|||||||
{name}
|
{name}
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
<br />
|
|
||||||
|
|
||||||
<a
|
<a
|
||||||
className="model-desc-link"
|
className="model-desc-link"
|
||||||
href={githubUrl}
|
href={githubUrl}
|
||||||
target="_blank"
|
target="_blank"
|
||||||
rel="noreferrer noopener"
|
rel="noreferrer noopener"
|
||||||
style={{ marginTop: '8px' }}
|
|
||||||
>
|
>
|
||||||
{githubUrl}
|
{githubUrl}
|
||||||
</a>
|
</a>
|
||||||
@ -49,6 +47,28 @@ function ModelSettingBlock() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderLDMModelDesc = () => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{renderModelDesc(
|
||||||
|
'High-Resolution Image Synthesis with Latent Diffusion Models',
|
||||||
|
'https://arxiv.org/abs/2112.10752',
|
||||||
|
'https://github.com/CompVis/latent-diffusion'
|
||||||
|
)}
|
||||||
|
<NumberInputSetting
|
||||||
|
title="Steps"
|
||||||
|
value={`${setting.ldmSteps}`}
|
||||||
|
onValue={value => {
|
||||||
|
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, ldmSteps: val }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const renderOptionDesc = (): ReactNode => {
|
const renderOptionDesc = (): ReactNode => {
|
||||||
switch (setting.model) {
|
switch (setting.model) {
|
||||||
case AIModel.LAMA:
|
case AIModel.LAMA:
|
||||||
@ -58,11 +78,7 @@ function ModelSettingBlock() {
|
|||||||
'https://github.com/saic-mdal/lama'
|
'https://github.com/saic-mdal/lama'
|
||||||
)
|
)
|
||||||
case AIModel.LDM:
|
case AIModel.LDM:
|
||||||
return renderModelDesc(
|
return renderLDMModelDesc()
|
||||||
'High-Resolution Image Synthesis with Latent Diffusion Models',
|
|
||||||
'https://arxiv.org/abs/2112.10752',
|
|
||||||
'https://github.com/CompVis/latent-diffusion'
|
|
||||||
)
|
|
||||||
default:
|
default:
|
||||||
return <></>
|
return <></>
|
||||||
}
|
}
|
@ -0,0 +1,40 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import NumberInput from '../shared/NumberInput'
|
||||||
|
import SettingBlock from './SettingBlock'
|
||||||
|
|
||||||
|
interface NumberInputSettingProps {
|
||||||
|
title: string
|
||||||
|
value: string
|
||||||
|
suffix?: string
|
||||||
|
onValue: (val: string) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
function NumberInputSetting(props: NumberInputSettingProps) {
|
||||||
|
const { title, value, suffix, onValue } = props
|
||||||
|
|
||||||
|
return (
|
||||||
|
<SettingBlock
|
||||||
|
className="sub-setting-block"
|
||||||
|
title={title}
|
||||||
|
input={
|
||||||
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: '8px',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<NumberInput
|
||||||
|
style={{ width: '80px' }}
|
||||||
|
value={`${value}`}
|
||||||
|
onValue={onValue}
|
||||||
|
/>
|
||||||
|
{suffix && <span>{suffix}</span>}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default NumberInputSetting
|
@ -1,14 +1,26 @@
|
|||||||
import React, { ReactNode } from 'react'
|
import React, { ReactNode } from 'react'
|
||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
|
import { settingState } from '../../store/Atoms'
|
||||||
import { Switch, SwitchThumb } from '../shared/Switch'
|
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||||
import SettingBlock from './SettingBlock'
|
import SettingBlock from './SettingBlock'
|
||||||
|
|
||||||
function SavePathSettingBlock() {
|
function SavePathSettingBlock() {
|
||||||
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
|
||||||
|
const onCheckChange = (checked: boolean) => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, saveImageBesideOrigin: checked }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<SettingBlock
|
<SettingBlock
|
||||||
title="Download image beside origin image"
|
title="Download image beside origin image"
|
||||||
input={
|
input={
|
||||||
<Switch defaultChecked>
|
<Switch
|
||||||
|
checked={setting.saveImageBesideOrigin}
|
||||||
|
onCheckedChange={onCheckChange}
|
||||||
|
>
|
||||||
<SwitchThumb />
|
<SwitchThumb />
|
||||||
</Switch>
|
</Switch>
|
||||||
}
|
}
|
@ -7,7 +7,7 @@
|
|||||||
margin-top: 12px;
|
margin-top: 12px;
|
||||||
border: 1px solid var(--border-color);
|
border: 1px solid var(--border-color);
|
||||||
border-radius: 0.3rem;
|
border-radius: 0.3rem;
|
||||||
padding: 2rem;
|
padding: 1rem;
|
||||||
|
|
||||||
.sub-setting-block {
|
.sub-setting-block {
|
||||||
margin-top: 8px;
|
margin-top: 8px;
|
@ -8,7 +8,6 @@
|
|||||||
background-color: var(--modal-bg);
|
background-color: var(--modal-bg);
|
||||||
color: var(--modal-text-color);
|
color: var(--modal-text-color);
|
||||||
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
|
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
|
||||||
min-height: 600px;
|
|
||||||
width: 700px;
|
width: 700px;
|
||||||
|
|
||||||
@include mobile {
|
@include mobile {
|
@ -1,11 +1,11 @@
|
|||||||
import React from 'react'
|
import React from 'react'
|
||||||
|
|
||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
|
import { switchModel } from '../../adapters/inpainting'
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState } from '../../store/Atoms'
|
||||||
import Modal from '../shared/Modal'
|
import Modal from '../shared/Modal'
|
||||||
import HDSettingBlock from './HDSettingBlock'
|
import HDSettingBlock from './HDSettingBlock'
|
||||||
import ModelSettingBlock from './ModelSettingBlock'
|
import ModelSettingBlock from './ModelSettingBlock'
|
||||||
import SavePathSettingBlock from './SavePathSettingBlock'
|
|
||||||
|
|
||||||
export default function SettingModal() {
|
export default function SettingModal() {
|
||||||
const [setting, setSettingState] = useRecoilState(settingState)
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
@ -14,6 +14,8 @@ export default function SettingModal() {
|
|||||||
setSettingState(old => {
|
setSettingState(old => {
|
||||||
return { ...old, show: false }
|
return { ...old, show: false }
|
||||||
})
|
})
|
||||||
|
|
||||||
|
switchModel(setting.model)
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -23,7 +25,9 @@ export default function SettingModal() {
|
|||||||
className="modal-setting"
|
className="modal-setting"
|
||||||
show={setting.show}
|
show={setting.show}
|
||||||
>
|
>
|
||||||
<SavePathSettingBlock />
|
{/* It's not possible because this poses a security risk */}
|
||||||
|
{/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
|
||||||
|
{/* <SavePathSettingBlock /> */}
|
||||||
<ModelSettingBlock />
|
<ModelSettingBlock />
|
||||||
<HDSettingBlock />
|
<HDSettingBlock />
|
||||||
</Modal>
|
</Modal>
|
@ -1,9 +1,7 @@
|
|||||||
import React from 'react'
|
import React from 'react'
|
||||||
import { useRecoilValue } from 'recoil'
|
|
||||||
import Editor from './Editor/Editor'
|
import Editor from './Editor/Editor'
|
||||||
import { shortcutsState } from '../store/Atoms'
|
|
||||||
import ShortcutsModal from './Shortcuts/ShortcutsModal'
|
import ShortcutsModal from './Shortcuts/ShortcutsModal'
|
||||||
import SettingModal from './Setting/SettingModal'
|
import SettingModal from './Settings/SettingsModal'
|
||||||
|
|
||||||
interface WorkspaceProps {
|
interface WorkspaceProps {
|
||||||
file: File
|
file: File
|
||||||
|
@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
|
|||||||
const ref = useRef(null)
|
const ref = useRef(null)
|
||||||
|
|
||||||
useClickAway(ref, () => {
|
useClickAway(ref, () => {
|
||||||
|
if (show) {
|
||||||
onClose?.()
|
onClose?.()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
useKeyPressEvent('Escape', e => {
|
useKeyPressEvent('Escape', e => {
|
||||||
|
if (show) {
|
||||||
onClose?.()
|
onClose?.()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { atom } from 'recoil'
|
import { atom } from 'recoil'
|
||||||
import { HDStrategy } from '../components/Setting/HDSettingBlock'
|
import { HDStrategy } from '../components/Settings/HDSettingBlock'
|
||||||
import { AIModel } from '../components/Setting/ModelSettingBlock'
|
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
||||||
|
|
||||||
export const fileState = atom<File | undefined>({
|
export const fileState = atom<File | undefined>({
|
||||||
key: 'fileState',
|
key: 'fileState',
|
||||||
@ -16,10 +16,15 @@ export interface Setting {
|
|||||||
show: boolean
|
show: boolean
|
||||||
saveImageBesideOrigin: boolean
|
saveImageBesideOrigin: boolean
|
||||||
model: AIModel
|
model: AIModel
|
||||||
|
|
||||||
|
// For LaMa
|
||||||
hdStrategy: HDStrategy
|
hdStrategy: HDStrategy
|
||||||
hdStrategyResizeLimit: number
|
hdStrategyResizeLimit: number
|
||||||
hdStrategyCropTrigerSize: number
|
hdStrategyCropTrigerSize: number
|
||||||
hdStrategyCropMargin: number
|
hdStrategyCropMargin: number
|
||||||
|
|
||||||
|
// For LDM
|
||||||
|
ldmSteps: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export const settingState = atom<Setting>({
|
export const settingState = atom<Setting>({
|
||||||
@ -28,9 +33,10 @@ export const settingState = atom<Setting>({
|
|||||||
show: false,
|
show: false,
|
||||||
saveImageBesideOrigin: false,
|
saveImageBesideOrigin: false,
|
||||||
model: AIModel.LAMA,
|
model: AIModel.LAMA,
|
||||||
hdStrategy: HDStrategy.ORIGINAL,
|
hdStrategy: HDStrategy.RESIZE,
|
||||||
hdStrategyResizeLimit: 2048,
|
hdStrategyResizeLimit: 2048,
|
||||||
hdStrategyCropTrigerSize: 2048,
|
hdStrategyCropTrigerSize: 2048,
|
||||||
hdStrategyCropMargin: 128,
|
hdStrategyCropMargin: 128,
|
||||||
|
ldmSteps: 50,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
@use '../components/Header/Header';
|
@use '../components/Header/Header';
|
||||||
@use '../components/Header/ThemeChanger';
|
@use '../components/Header/ThemeChanger';
|
||||||
@use '../components/Shortcuts/Shortcuts';
|
@use '../components/Shortcuts/Shortcuts';
|
||||||
@use '../components/Setting/Setting.scss';
|
@use '../components/Settings/Settings.scss';
|
||||||
|
|
||||||
// Shared
|
// Shared
|
||||||
@use '../components/FileSelect/FileSelect';
|
@use '../components/FileSelect/FileSelect';
|
||||||
|
@ -31,7 +31,11 @@ def ceil_modulo(x, mod):
|
|||||||
|
|
||||||
|
|
||||||
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
||||||
data = cv2.imencode(f".{ext}", image_numpy)[1]
|
data = cv2.imencode(f".{ext}", image_numpy,
|
||||||
|
[
|
||||||
|
int(cv2.IMWRITE_JPEG_QUALITY), 100,
|
||||||
|
int(cv2.IMWRITE_PNG_COMPRESSION), 0
|
||||||
|
])[1]
|
||||||
image_bytes = data.tobytes()
|
image_bytes = data.tobytes()
|
||||||
return image_bytes
|
return image_bytes
|
||||||
|
|
||||||
@ -74,13 +78,24 @@ def resize_max_size(
|
|||||||
return np_img
|
return np_img
|
||||||
|
|
||||||
|
|
||||||
def pad_img_to_modulo(img, mod):
|
def pad_img_to_modulo(img: np.ndarray, mod: int):
|
||||||
channels, height, width = img.shape
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: [H, W, C]
|
||||||
|
mod:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(img.shape) == 2:
|
||||||
|
img = img[:, :, np.newaxis]
|
||||||
|
height, width = img.shape[:2]
|
||||||
out_height = ceil_modulo(height, mod)
|
out_height = ceil_modulo(height, mod)
|
||||||
out_width = ceil_modulo(width, mod)
|
out_width = ceil_modulo(width, mod)
|
||||||
return np.pad(
|
return np.pad(
|
||||||
img,
|
img,
|
||||||
((0, 0), (0, out_height - height), (0, out_width - width)),
|
((0, out_height - height), (0, out_width - width), (0, 0)),
|
||||||
mode="symmetric",
|
mode="symmetric",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -88,15 +103,13 @@ def pad_img_to_modulo(img, mod):
|
|||||||
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
mask: (1, h, w) 0~1
|
mask: (h, w, 1) 0~255
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
height, width = mask.shape[1:]
|
height, width = mask.shape[:2]
|
||||||
_, thresh = cv2.threshold(
|
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
||||||
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
|
|
||||||
)
|
|
||||||
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
boxes = []
|
boxes = []
|
||||||
|
@ -1,121 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
|
|
||||||
|
|
||||||
LAMA_MODEL_URL = os.environ.get(
|
|
||||||
"LAMA_MODEL_URL",
|
|
||||||
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LaMa:
|
|
||||||
def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
|
|
||||||
"""
|
|
||||||
|
|
||||||
Args:
|
|
||||||
crop_trigger_size: h, w
|
|
||||||
crop_margin:
|
|
||||||
device:
|
|
||||||
"""
|
|
||||||
self.crop_trigger_size = crop_trigger_size
|
|
||||||
self.crop_margin = crop_margin
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
if os.environ.get("LAMA_MODEL"):
|
|
||||||
model_path = os.environ.get("LAMA_MODEL")
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"lama torchscript model not found: {model_path}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model_path = download_model(LAMA_MODEL_URL)
|
|
||||||
|
|
||||||
print(f"Load LaMa model from: {model_path}")
|
|
||||||
model = torch.jit.load(model_path, map_location="cpu")
|
|
||||||
model = model.to(device)
|
|
||||||
model.eval()
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(self, image, mask):
|
|
||||||
"""
|
|
||||||
image: [C, H, W] RGB
|
|
||||||
mask: [1, H, W]
|
|
||||||
return: BGR IMAGE
|
|
||||||
"""
|
|
||||||
area = image.shape[1] * image.shape[2]
|
|
||||||
if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
|
|
||||||
return self._run(image, mask)
|
|
||||||
|
|
||||||
print("Trigger crop image")
|
|
||||||
boxes = boxes_from_mask(mask)
|
|
||||||
crop_result = []
|
|
||||||
for box in boxes:
|
|
||||||
crop_image, crop_box = self._run_box(image, mask, box)
|
|
||||||
crop_result.append((crop_image, crop_box))
|
|
||||||
|
|
||||||
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
|
|
||||||
for crop_image, crop_box in crop_result:
|
|
||||||
x1, y1, x2, y2 = crop_box
|
|
||||||
image[y1:y2, x1:x2, :] = crop_image
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _run_box(self, image, mask, box):
|
|
||||||
"""
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: [C, H, W] RGB
|
|
||||||
mask: [1, H, W]
|
|
||||||
box: [left,top,right,bottom]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BGR IMAGE
|
|
||||||
"""
|
|
||||||
box_h = box[3] - box[1]
|
|
||||||
box_w = box[2] - box[0]
|
|
||||||
cx = (box[0] + box[2]) // 2
|
|
||||||
cy = (box[1] + box[3]) // 2
|
|
||||||
img_h, img_w = image.shape[1:]
|
|
||||||
|
|
||||||
w = box_w + self.crop_margin * 2
|
|
||||||
h = box_h + self.crop_margin * 2
|
|
||||||
|
|
||||||
l = max(cx - w // 2, 0)
|
|
||||||
t = max(cy - h // 2, 0)
|
|
||||||
r = min(cx + w // 2, img_w)
|
|
||||||
b = min(cy + h // 2, img_h)
|
|
||||||
|
|
||||||
crop_img = image[:, t:b, l:r]
|
|
||||||
crop_mask = mask[:, t:b, l:r]
|
|
||||||
|
|
||||||
print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
|
||||||
|
|
||||||
return self._run(crop_img, crop_mask), [l, t, r, b]
|
|
||||||
|
|
||||||
def _run(self, image, mask):
|
|
||||||
"""
|
|
||||||
image: [C, H, W] RGB
|
|
||||||
mask: [1, H, W]
|
|
||||||
return: BGR IMAGE
|
|
||||||
"""
|
|
||||||
device = self.device
|
|
||||||
origin_height, origin_width = image.shape[1:]
|
|
||||||
image = pad_img_to_modulo(image, mod=8)
|
|
||||||
mask = pad_img_to_modulo(mask, mod=8)
|
|
||||||
|
|
||||||
mask = (mask > 0) * 1
|
|
||||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
|
||||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
inpainted_image = self.model(image, mask)
|
|
||||||
|
|
||||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
|
||||||
cur_res = cur_res[0:origin_height, 0:origin_width, :]
|
|
||||||
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
|
||||||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
|
||||||
return cur_res
|
|
122
lama_cleaner/model/base.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
|
||||||
|
from lama_cleaner.schema import Config, HDStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintModel:
|
||||||
|
pad_mod = 8
|
||||||
|
|
||||||
|
def __init__(self, device):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device:
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
self.init_model(device)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def init_model(self, device):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def forward(self, image, mask, config: Config):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W]
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def _pad_forward(self, image, mask, config: Config):
|
||||||
|
origin_height, origin_width = image.shape[:2]
|
||||||
|
padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
|
||||||
|
padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
|
||||||
|
result = self.forward(padd_image, padd_mask, config)
|
||||||
|
result = result[0:origin_height, 0:origin_width, :]
|
||||||
|
|
||||||
|
original_pixel_indices = mask != 255
|
||||||
|
result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||||
|
return result
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, image, mask, config: Config):
|
||||||
|
"""
|
||||||
|
image: [H, W, C] RGB, not normalized
|
||||||
|
mask: [H, W]
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
inpaint_result = None
|
||||||
|
logger.info(f"hd_strategy: {config.hd_strategy}")
|
||||||
|
if config.hd_strategy == HDStrategy.CROP:
|
||||||
|
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
||||||
|
logger.info(f"Run crop strategy")
|
||||||
|
boxes = boxes_from_mask(mask)
|
||||||
|
crop_result = []
|
||||||
|
for box in boxes:
|
||||||
|
crop_image, crop_box = self._run_box(image, mask, box, config)
|
||||||
|
crop_result.append((crop_image, crop_box))
|
||||||
|
|
||||||
|
inpaint_result = image[:, :, ::-1]
|
||||||
|
for crop_image, crop_box in crop_result:
|
||||||
|
x1, y1, x2, y2 = crop_box
|
||||||
|
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
||||||
|
|
||||||
|
elif config.hd_strategy == HDStrategy.RESIZE:
|
||||||
|
if max(image.shape) > config.hd_strategy_resize_limit:
|
||||||
|
origin_size = image.shape[:2]
|
||||||
|
downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
|
||||||
|
downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
|
||||||
|
|
||||||
|
logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
|
||||||
|
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||||
|
|
||||||
|
# only paste masked area result
|
||||||
|
inpaint_result = cv2.resize(inpaint_result,
|
||||||
|
(origin_size[1], origin_size[0]),
|
||||||
|
interpolation=cv2.INTER_CUBIC)
|
||||||
|
|
||||||
|
original_pixel_indices = mask != 255
|
||||||
|
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||||
|
|
||||||
|
if inpaint_result is None:
|
||||||
|
inpaint_result = self._pad_forward(image, mask, config)
|
||||||
|
|
||||||
|
return inpaint_result
|
||||||
|
|
||||||
|
def _run_box(self, image, mask, box, config: Config):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1]
|
||||||
|
box: [left,top,right,bottom]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BGR IMAGE
|
||||||
|
"""
|
||||||
|
box_h = box[3] - box[1]
|
||||||
|
box_w = box[2] - box[0]
|
||||||
|
cx = (box[0] + box[2]) // 2
|
||||||
|
cy = (box[1] + box[3]) // 2
|
||||||
|
img_h, img_w = image.shape[:2]
|
||||||
|
|
||||||
|
w = box_w + config.hd_strategy_crop_margin * 2
|
||||||
|
h = box_h + config.hd_strategy_crop_margin * 2
|
||||||
|
|
||||||
|
l = max(cx - w // 2, 0)
|
||||||
|
t = max(cy - h // 2, 0)
|
||||||
|
r = min(cx + w // 2, img_w)
|
||||||
|
b = min(cy + h // 2, img_h)
|
||||||
|
|
||||||
|
crop_img = image[t:b, l:r, :]
|
||||||
|
crop_mask = mask[t:b, l:r]
|
||||||
|
|
||||||
|
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
||||||
|
|
||||||
|
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
64
lama_cleaner/model/lama.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
|
||||||
|
from lama_cleaner.model.base import InpaintModel
|
||||||
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
|
LAMA_MODEL_URL = os.environ.get(
|
||||||
|
"LAMA_MODEL_URL",
|
||||||
|
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LaMa(InpaintModel):
|
||||||
|
pad_mod = 8
|
||||||
|
|
||||||
|
def __init__(self, device):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device:
|
||||||
|
"""
|
||||||
|
super().__init__(device)
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def init_model(self, device):
|
||||||
|
if os.environ.get("LAMA_MODEL"):
|
||||||
|
model_path = os.environ.get("LAMA_MODEL")
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"lama torchscript model not found: {model_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_path = download_model(LAMA_MODEL_URL)
|
||||||
|
|
||||||
|
logger.info(f"Load LaMa model from: {model_path}")
|
||||||
|
model = torch.jit.load(model_path, map_location="cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(self, image, mask, config: Config):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W]
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
image = norm_img(image)
|
||||||
|
mask = norm_img(mask)
|
||||||
|
|
||||||
|
mask = (mask > 0) * 1
|
||||||
|
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||||
|
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
inpainted_image = self.model(image, mask)
|
||||||
|
|
||||||
|
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||||
|
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
||||||
|
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
||||||
|
return cur_res
|
@ -2,13 +2,16 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.model.base import InpaintModel
|
||||||
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import cv2
|
from lama_cleaner.helper import download_model, norm_img
|
||||||
from lama_cleaner.helper import pad_img_to_modulo, download_model
|
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
||||||
from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
|
||||||
timestep_embedding
|
timestep_embedding
|
||||||
|
|
||||||
LDM_ENCODE_MODEL_URL = os.environ.get(
|
LDM_ENCODE_MODEL_URL = os.environ.get(
|
||||||
@ -217,7 +220,7 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||||
|
|
||||||
@ -268,97 +271,39 @@ def load_jit_model(url, device):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class LDM:
|
class LDM(InpaintModel):
|
||||||
def __init__(self, device, steps=50):
|
pad_mod = 32
|
||||||
|
|
||||||
|
def __init__(self, device):
|
||||||
|
super().__init__(device)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
def init_model(self, device):
|
||||||
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
||||||
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
||||||
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
||||||
|
|
||||||
model = LatentDiffusion(self.diffusion_model, device)
|
model = LatentDiffusion(self.diffusion_model, device)
|
||||||
self.sampler = DDIMSampler(model)
|
self.sampler = DDIMSampler(model)
|
||||||
self.steps = steps
|
|
||||||
|
|
||||||
def _norm(self, tensor):
|
def forward(self, image, mask, config: Config):
|
||||||
return tensor * 2.0 - 1.0
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __call__(self, image, mask):
|
|
||||||
"""
|
"""
|
||||||
image: [C, H, W] RGB
|
image: [H, W, C] RGB
|
||||||
mask: [1, H, W]
|
mask: [H, W, 1]
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
# image [1,3,512,512] float32
|
# image [1,3,512,512] float32
|
||||||
# mask: [1,1,512,512] float32
|
# mask: [1,1,512,512] float32
|
||||||
# masked_image: [1,3,512,512] float32
|
# masked_image: [1,3,512,512] float32
|
||||||
origin_height, origin_width = image.shape[1:]
|
steps = config.ldm_steps
|
||||||
image = pad_img_to_modulo(image, mod=32)
|
image = norm_img(image)
|
||||||
mask = pad_img_to_modulo(mask, mod=32)
|
mask = norm_img(mask)
|
||||||
padded_height, padded_width = image.shape[1:]
|
|
||||||
mask[mask < 0.5] = 0
|
mask[mask < 0.5] = 0
|
||||||
mask[mask >= 0.5] = 1
|
mask[mask >= 0.5] = 1
|
||||||
|
|
||||||
# crop 512 x 512
|
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||||
if padded_width <= 512 or padded_height <= 512:
|
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||||
np_img = self._forward(image, mask, self.device)
|
|
||||||
else:
|
|
||||||
print("Try to zoom in")
|
|
||||||
# zoom in
|
|
||||||
# x,y,w,h
|
|
||||||
# box = self.box_from_bitmap(mask)
|
|
||||||
box = self.find_main_content(mask)
|
|
||||||
if box is None:
|
|
||||||
print("No bbox found")
|
|
||||||
np_img = self._forward(image, mask, self.device)
|
|
||||||
else:
|
|
||||||
print(f"box: {box}")
|
|
||||||
box_x, box_y, box_w, box_h = box
|
|
||||||
cx = box_x + box_w // 2
|
|
||||||
cy = box_y + box_h // 2
|
|
||||||
|
|
||||||
# w = max(512, box_w)
|
|
||||||
# h = max(512, box_h)
|
|
||||||
w = box_w + 512
|
|
||||||
h = box_h + 512
|
|
||||||
|
|
||||||
left = max(cx - w // 2, 0)
|
|
||||||
top = max(cy - h // 2, 0)
|
|
||||||
right = min(cx + w // 2, origin_width)
|
|
||||||
bottom = min(cy + h // 2, origin_height)
|
|
||||||
|
|
||||||
x = left
|
|
||||||
y = top
|
|
||||||
w = right - left
|
|
||||||
h = bottom - top
|
|
||||||
|
|
||||||
crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
|
|
||||||
crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
|
|
||||||
|
|
||||||
print(f"Apply zoom in size width x height: {crop_img.shape}")
|
|
||||||
|
|
||||||
crop_img_height, crop_img_width = crop_img.shape[1:]
|
|
||||||
|
|
||||||
crop_img = pad_img_to_modulo(crop_img, mod=32)
|
|
||||||
crop_mask = pad_img_to_modulo(crop_mask, mod=32)
|
|
||||||
# RGB
|
|
||||||
np_img = self._forward(crop_img, crop_mask, self.device)
|
|
||||||
|
|
||||||
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
|
|
||||||
image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
|
|
||||||
np_img = image
|
|
||||||
# BGR to RGB
|
|
||||||
# np_img = image[:, :, ::-1]
|
|
||||||
|
|
||||||
np_img = np_img[0:origin_height, 0:origin_width, :]
|
|
||||||
np_img = np_img[:, :, ::-1]
|
|
||||||
|
|
||||||
return np_img
|
|
||||||
|
|
||||||
def _forward(self, image, mask, device):
|
|
||||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
|
||||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
|
||||||
masked_image = (1 - mask) * image
|
masked_image = (1 - mask) * image
|
||||||
|
|
||||||
image = self._norm(image)
|
image = self._norm(image)
|
||||||
@ -371,47 +316,20 @@ class LDM:
|
|||||||
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
||||||
|
|
||||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
shape = (c.shape[1] - 1,) + c.shape[2:]
|
||||||
samples_ddim = self.sampler.sample(steps=self.steps,
|
samples_ddim = self.sampler.sample(steps=steps,
|
||||||
conditioning=c,
|
conditioning=c,
|
||||||
batch_size=c.shape[0],
|
batch_size=c.shape[0],
|
||||||
shape=shape)
|
shape=shape)
|
||||||
x_samples_ddim = self.cond_stage_model_decode(samples_ddim) # samples_ddim: 1, 3, 128, 128 float32
|
x_samples_ddim = self.cond_stage_model_decode(samples_ddim) # samples_ddim: 1, 3, 128, 128 float32
|
||||||
|
|
||||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
inpainted = (1 - mask) * image + mask * predicted_image
|
# inpainted = (1 - mask) * image + mask * predicted_image
|
||||||
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
||||||
np_img = inpainted.astype(np.uint8)
|
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
|
||||||
return np_img
|
return inpainted_image
|
||||||
|
|
||||||
def find_main_content(self, bitmap: np.ndarray):
|
def _norm(self, tensor):
|
||||||
th2 = bitmap[0].astype(np.uint8)
|
return tensor * 2.0 - 1.0
|
||||||
row_sum = th2.sum(1)
|
|
||||||
col_sum = th2.sum(0)
|
|
||||||
xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
|
|
||||||
xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
|
|
||||||
ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
|
|
||||||
ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
|
|
||||||
|
|
||||||
left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
|
|
||||||
return left, top, right - left, bottom - top
|
|
||||||
|
|
||||||
def box_from_bitmap(self, bitmap):
|
|
||||||
"""
|
|
||||||
bitmap: single map with shape (NUM_CLASSES, H, W),
|
|
||||||
whose values are binarized as {0, 1}
|
|
||||||
"""
|
|
||||||
contours, _ = cv2.findContours(
|
|
||||||
(bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
|
|
||||||
)
|
|
||||||
|
|
||||||
contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
|
|
||||||
num_contours = len(contours)
|
|
||||||
print(f"contours size: {num_contours}")
|
|
||||||
if num_contours != 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# x,y,w,h
|
|
||||||
return cv2.boundingRect(contours[0])
|
|
34
lama_cleaner/model_manager.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from lama_cleaner.model.lama import LaMa
|
||||||
|
from lama_cleaner.model.ldm import LDM
|
||||||
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
LAMA = 'lama'
|
||||||
|
LDM = 'ldm'
|
||||||
|
|
||||||
|
def __init__(self, name: str, device):
|
||||||
|
self.name = name
|
||||||
|
self.device = device
|
||||||
|
self.model = self.init_model(name, device)
|
||||||
|
|
||||||
|
def init_model(self, name: str, device):
|
||||||
|
if name == self.LAMA:
|
||||||
|
model = LaMa(device)
|
||||||
|
elif name == self.LDM:
|
||||||
|
model = LDM(device)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Not supported model: {name}")
|
||||||
|
return model
|
||||||
|
|
||||||
|
def __call__(self, image, mask, config: Config):
|
||||||
|
return self.model(image, mask, config)
|
||||||
|
|
||||||
|
def switch(self, new_name: str):
|
||||||
|
if new_name == self.name:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.model = self.init_model(new_name, self.device)
|
||||||
|
self.name = new_name
|
||||||
|
except NotImplementedError as e:
|
||||||
|
raise e
|
17
lama_cleaner/schema.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class HDStrategy(str, Enum):
|
||||||
|
ORIGINAL = 'Original'
|
||||||
|
RESIZE = 'Resize'
|
||||||
|
CROP = 'Crop'
|
||||||
|
|
||||||
|
|
||||||
|
class Config(BaseModel):
|
||||||
|
ldm_steps: int
|
||||||
|
hd_strategy: str
|
||||||
|
hd_strategy_crop_margin: int
|
||||||
|
hd_strategy_crop_trigger_size: int
|
||||||
|
hd_strategy_resize_limit: int
|
0
lama_cleaner/tests/__init__.py
Normal file
BIN
lama_cleaner/tests/image.png
Normal file
After Width: | Height: | Size: 129 KiB |
BIN
lama_cleaner/tests/lama_crop_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/lama_original_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/lama_resize_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/ldm_crop_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/ldm_original_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/ldm_resize_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
Before Width: | Height: | Size: 11 KiB |
BIN
lama_cleaner/tests/mask.png
Normal file
After Width: | Height: | Size: 7.7 KiB |
@ -1,15 +0,0 @@
|
|||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from lama_cleaner.helper import boxes_from_mask
|
|
||||||
|
|
||||||
|
|
||||||
def test_boxes_from_mask():
|
|
||||||
mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
|
|
||||||
mask = mask[:, :, np.newaxis]
|
|
||||||
mask = (mask / 255).transpose(2, 0, 1)
|
|
||||||
boxes = boxes_from_mask(mask)
|
|
||||||
print(boxes)
|
|
||||||
|
|
||||||
|
|
||||||
test_boxes_from_mask()
|
|
55
lama_cleaner/tests/test_model.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.schema import Config, HDStrategy
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def get_data():
|
||||||
|
img = cv2.imread(str(current_dir / 'image.png'))
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||||
|
mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(strategy):
|
||||||
|
return Config(
|
||||||
|
ldm_steps=1,
|
||||||
|
hd_strategy=strategy,
|
||||||
|
hd_strategy_crop_margin=32,
|
||||||
|
hd_strategy_crop_trigger_size=200,
|
||||||
|
hd_strategy_resize_limit=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_equal(model, config, gt_name):
|
||||||
|
img, mask = get_data()
|
||||||
|
res = model(img, mask, config)
|
||||||
|
# cv2.imwrite(gt_name, res,
|
||||||
|
# [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])
|
||||||
|
|
||||||
|
"""
|
||||||
|
Note that JPEG is lossy compression, so even if it is the highest quality 100,
|
||||||
|
when the saved image is reloaded, a difference occurs with the original pixel value.
|
||||||
|
If you want to save the original image as it is, save it as PNG or BMP.
|
||||||
|
"""
|
||||||
|
gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
|
||||||
|
assert np.array_equal(res, gt)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
|
||||||
|
def test_lama(strategy):
|
||||||
|
model = ModelManager(name='lama', device='cpu')
|
||||||
|
assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
|
||||||
|
def test_ldm(strategy):
|
||||||
|
model = ModelManager(name='ldm', device='cpu')
|
||||||
|
assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')
|
94
main.py
@ -2,19 +2,21 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import imghdr
|
import imghdr
|
||||||
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lama_cleaner.lama import LaMa
|
from loguru import logger
|
||||||
from lama_cleaner.ldm import LDM
|
|
||||||
|
|
||||||
from flaskwebgui import FlaskUI
|
from lama_cleaner.model_manager import ModelManager
|
||||||
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||||
@ -29,7 +31,6 @@ from flask_cors import CORS
|
|||||||
|
|
||||||
from lama_cleaner.helper import (
|
from lama_cleaner.helper import (
|
||||||
load_img,
|
load_img,
|
||||||
norm_img,
|
|
||||||
numpy_to_bytes,
|
numpy_to_bytes,
|
||||||
resize_max_size,
|
resize_max_size,
|
||||||
)
|
)
|
||||||
@ -46,11 +47,19 @@ if os.environ.get("CACHE_DIR"):
|
|||||||
|
|
||||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
|
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
|
||||||
|
|
||||||
|
|
||||||
|
class InterceptHandler(logging.Handler):
|
||||||
|
def emit(self, record):
|
||||||
|
logger_opt = logger.opt(depth=6, exception=record.exc_info)
|
||||||
|
logger_opt.log(record.levelno, record.getMessage())
|
||||||
|
|
||||||
|
|
||||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||||
app.config["JSON_AS_ASCII"] = False
|
app.config["JSON_AS_ASCII"] = False
|
||||||
CORS(app)
|
app.logger.addHandler(InterceptHandler())
|
||||||
|
CORS(app, expose_headers=["Content-Disposition"])
|
||||||
|
|
||||||
model = None
|
model: ModelManager = None
|
||||||
device = None
|
device = None
|
||||||
input_image_path: str = None
|
input_image_path: str = None
|
||||||
|
|
||||||
@ -72,24 +81,31 @@ def process():
|
|||||||
original_shape = image.shape
|
original_shape = image.shape
|
||||||
interpolation = cv2.INTER_CUBIC
|
interpolation = cv2.INTER_CUBIC
|
||||||
|
|
||||||
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
|
form = request.form
|
||||||
|
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
|
||||||
if size_limit == "Original":
|
if size_limit == "Original":
|
||||||
size_limit = max(image.shape)
|
size_limit = max(image.shape)
|
||||||
else:
|
else:
|
||||||
size_limit = int(size_limit)
|
size_limit = int(size_limit)
|
||||||
|
|
||||||
print(f"Origin image shape: {original_shape}")
|
config = Config(
|
||||||
|
ldm_steps=form['ldmSteps'],
|
||||||
|
hd_strategy=form['hdStrategy'],
|
||||||
|
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
|
||||||
|
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
|
||||||
|
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Origin image shape: {original_shape}")
|
||||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||||
print(f"Resized image shape: {image.shape}")
|
logger.info(f"Resized image shape: {image.shape}")
|
||||||
image = norm_img(image)
|
|
||||||
|
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||||
mask = norm_img(mask)
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
res_np_img = model(image, mask)
|
res_np_img = model(image, mask, config)
|
||||||
print(f"process time: {(time.time() - start) * 1000}ms")
|
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -109,6 +125,19 @@ def process():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/switch_model", methods=["POST"])
|
||||||
|
def switch_model():
|
||||||
|
new_name = request.form.get("name")
|
||||||
|
if new_name == model.name:
|
||||||
|
return "Same model", 200
|
||||||
|
|
||||||
|
try:
|
||||||
|
model.switch(new_name)
|
||||||
|
except NotImplementedError:
|
||||||
|
return f"{new_name} not implemented", 403
|
||||||
|
return f"ok, switch to {new_name}", 200
|
||||||
|
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
def index():
|
def index():
|
||||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||||
@ -120,7 +149,9 @@ def set_input_photo():
|
|||||||
with open(input_image_path, "rb") as f:
|
with open(input_image_path, "rb") as f:
|
||||||
image_in_bytes = f.read()
|
image_in_bytes = f.read()
|
||||||
return send_file(
|
return send_file(
|
||||||
io.BytesIO(image_in_bytes),
|
input_image_path,
|
||||||
|
as_attachment=True,
|
||||||
|
download_name=Path(input_image_path).name,
|
||||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -135,29 +166,6 @@ def get_args_parser():
|
|||||||
parser.add_argument("--host", default="127.0.0.1")
|
parser.add_argument("--host", default="127.0.0.1")
|
||||||
parser.add_argument("--port", default=8080, type=int)
|
parser.add_argument("--port", default=8080, type=int)
|
||||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||||
parser.add_argument(
|
|
||||||
"--crop-trigger-size",
|
|
||||||
default=[2042, 2042],
|
|
||||||
nargs=2,
|
|
||||||
type=int,
|
|
||||||
help="If image size large then crop-trigger-size, "
|
|
||||||
"crop each area from original image to do inference."
|
|
||||||
"Mainly for performance and memory reasons"
|
|
||||||
"Only for lama",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--crop-margin",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Margin around bounding box of painted stroke when crop mode triggered",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ldm-steps",
|
|
||||||
default=50,
|
|
||||||
type=int,
|
|
||||||
help="Steps for DDIM sampling process."
|
|
||||||
"The larger the value, the better the result, but it will be more time-consuming",
|
|
||||||
)
|
|
||||||
parser.add_argument("--device", default="cuda", type=str)
|
parser.add_argument("--device", default="cuda", type=str)
|
||||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -188,19 +196,11 @@ def main():
|
|||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
input_image_path = args.input
|
input_image_path = args.input
|
||||||
|
|
||||||
if args.model == "lama":
|
model = ModelManager(name=args.model, device=device)
|
||||||
model = LaMa(
|
|
||||||
crop_trigger_size=args.crop_trigger_size,
|
|
||||||
crop_margin=args.crop_margin,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
elif args.model == "ldm":
|
|
||||||
model = LDM(device, steps=args.ldm_steps)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Not supported model: {args.model}")
|
|
||||||
|
|
||||||
if args.gui:
|
if args.gui:
|
||||||
app_width, app_height = args.gui_size
|
app_width, app_height = args.gui_size
|
||||||
|
from flaskwebgui import FlaskUI
|
||||||
ui = FlaskUI(app, width=app_width, height=app_height)
|
ui = FlaskUI(app, width=app_width, height=app_height)
|
||||||
ui.run()
|
ui.run()
|
||||||
else:
|
else:
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
torch>=1.8.2
|
torch>=1.8.2
|
||||||
opencv-python
|
opencv-python
|
||||||
flask_cors
|
flask_cors
|
||||||
flask
|
flask==2.1.1
|
||||||
flaskwebgui
|
flaskwebgui
|
||||||
tqdm
|
tqdm
|
||||||
|
pydantic
|
||||||
|
loguru
|
||||||
|
pytest
|
||||||
|