big update

Sanster 2022-04-16 00:11:51 +08:00
parent 2b031603ed
commit 205286a414
40 changed files with 539 additions and 376 deletions

View File

@@ -1,10 +1,12 @@
+import { Setting } from '../store/Atoms'
import { dataURItoBlob } from '../utils'

-export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
+export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`

export default async function inpaint(
  imageFile: File,
  maskBase64: string,
+  settings: Setting,
  sizeLimit?: string
) {
  // 1080, 2000, Original
@@ -13,13 +15,22 @@ export default async function inpaint(
  const mask = dataURItoBlob(maskBase64)
  fd.append('mask', mask)

+  fd.append('ldmSteps', settings.ldmSteps.toString())
+  fd.append('hdStrategy', settings.hdStrategy)
+  fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
+  fd.append(
+    'hdStrategyCropTrigerSize',
+    settings.hdStrategyCropTrigerSize.toString()
+  )
+  fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())

  if (sizeLimit === undefined) {
    fd.append('sizeLimit', '1080')
  } else {
    fd.append('sizeLimit', sizeLimit)
  }

-  const res = await fetch(API_ENDPOINT, {
+  const res = await fetch(`${API_ENDPOINT}/inpaint`, {
    method: 'POST',
    body: fd,
  }).then(async r => {
@@ -28,3 +39,12 @@ export default async function inpaint(
  return URL.createObjectURL(res)
}

+export function switchModel(name: string) {
+  const fd = new FormData()
+  fd.append('name', name)
+
+  return fetch(`${API_ENDPOINT}/switch_model`, {
+    method: 'POST',
+    body: fd,
+  })
+}

View File

@@ -15,12 +15,14 @@ import {
  TransformComponent,
  TransformWrapper,
} from 'react-zoom-pan-pinch'
+import { useRecoilValue } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'

import inpaint from '../../adapters/inpainting'
import Button from '../shared/Button'
import Slider from './Slider'
import SizeSelector from './SizeSelector'
import { downloadImage, loadImage, useImage } from '../../utils'
+import { settingState } from '../../store/Atoms'

const TOOLBAR_SIZE = 200
const BRUSH_COLOR = '#ffcc00bb'
@@ -57,6 +59,7 @@ function drawLines(

export default function Editor(props: EditorProps) {
  const { file } = props
+  const settings = useRecoilValue(settingState)
  const [brushSize, setBrushSize] = useState(40)
  const [original, isOriginalLoaded] = useImage(file)
  const [renders, setRenders] = useState<HTMLImageElement[]>([])
@@ -125,6 +128,7 @@ export default function Editor(props: EditorProps) {
      const res = await inpaint(
        file,
        maskCanvas.toDataURL(),
+        settings,
        sizeLimit.toString()
      )
      if (!res) {
@@ -157,6 +161,7 @@ export default function Editor(props: EditorProps) {
    renders,
    sizeLimit,
    historyLineCount,
+    settings,
  ])

  const hadDrawSomething = () => {

View File

@@ -6,7 +6,7 @@ import Button from '../shared/Button'
import Shortcuts from '../Shortcuts/Shortcuts'
import useResolution from '../../hooks/useResolution'
import { ThemeChanger } from './ThemeChanger'
-import SettingIcon from '../Setting/SettingIcon'
+import SettingIcon from '../Settings/SettingIcon'

const Header = () => {
  const [file, setFile] = useRecoilState(fileState)

View File

@@ -1,50 +1,16 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
-import NumberInput from '../shared/NumberInput'
import Selector from '../shared/Selector'
+import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'

export enum HDStrategy {
  ORIGINAL = 'Original',
-  REISIZE = 'Resize',
+  RESIZE = 'Resize',
  CROP = 'Crop',
}

-interface PixelSizeInputProps {
-  title: string
-  value: string
-  onValue: (val: string) => void
-}
-
-function PixelSizeInputSetting(props: PixelSizeInputProps) {
-  const { title, value, onValue } = props
-
-  return (
-    <SettingBlock
-      className="sub-setting-block"
-      title={title}
-      input={
-        <div
-          style={{
-            display: 'flex',
-            justifyContent: 'center',
-            alignItems: 'center',
-            gap: '8px',
-          }}
-        >
-          <NumberInput
-            style={{ width: '80px' }}
-            value={`${value}`}
-            onValue={onValue}
-          />
-          <span>pixel</span>
-        </div>
-      }
-    />
-  )
-}
-
function HDSettingBlock() {
  const [setting, setSettingState] = useRecoilState(settingState)
@@ -84,7 +50,7 @@ function HDSettingBlock() {
          tabIndex={0}
          role="button"
          className="inline-tip"
-          onClick={() => onStrategyChange(HDStrategy.REISIZE)}
+          onClick={() => onStrategyChange(HDStrategy.RESIZE)}
        >
          Resize Strategy
        </div>{' '}
@@ -100,9 +66,10 @@ function HDSettingBlock() {
          Resize the longer side of the image to a specific size(keep ratio),
          then do inpainting on the resized image.
        </div>
-        <PixelSizeInputSetting
+        <NumberInputSetting
          title="Size limit"
          value={`${setting.hdStrategyResizeLimit}`}
+          suffix="pixel"
          onValue={onResizeLimitChange}
        />
      </div>
@@ -117,14 +84,16 @@ function HDSettingBlock() {
          the result back. Mainly for performance and memory reasons on high
          resolution image.
        </div>
-        <PixelSizeInputSetting
+        <NumberInputSetting
          title="Trigger size"
          value={`${setting.hdStrategyCropTrigerSize}`}
+          suffix="pixel"
          onValue={onCropTriggerSizeChange}
        />
-        <PixelSizeInputSetting
+        <NumberInputSetting
          title="Crop margin"
          value={`${setting.hdStrategyCropMargin}`}
+          suffix="pixel"
          onValue={onCropMarginChange}
        />
      </div>
@@ -137,7 +106,7 @@ function HDSettingBlock() {
        return renderOriginalOptionDesc()
      case HDStrategy.CROP:
        return renderCropOptionDesc()
-      case HDStrategy.REISIZE:
+      case HDStrategy.RESIZE:
        return renderResizeOptionDesc()
      default:
        return renderOriginalOptionDesc()

View File

@@ -2,11 +2,12 @@ import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Selector from '../shared/Selector'
+import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'

export enum AIModel {
-  LAMA = 'LaMa',
-  LDM = 'LDM',
+  LAMA = 'lama',
+  LDM = 'ldm',
}

function ModelSettingBlock() {
@@ -24,7 +25,7 @@ function ModelSettingBlock() {
    githubUrl: string
  ) => {
    return (
-      <div style={{ display: 'flex', flexDirection: 'column' }}>
+      <div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
        <a
          className="model-desc-link"
          href={paperUrl}
@@ -34,14 +35,11 @@ function ModelSettingBlock() {
          {name}
        </a>
-        <br />
        <a
          className="model-desc-link"
          href={githubUrl}
          target="_blank"
          rel="noreferrer noopener"
-          style={{ marginTop: '8px' }}
        >
          {githubUrl}
        </a>
@@ -49,6 +47,28 @@ function ModelSettingBlock() {
    )
  }

+  const renderLDMModelDesc = () => {
+    return (
+      <div>
+        {renderModelDesc(
+          'High-Resolution Image Synthesis with Latent Diffusion Models',
+          'https://arxiv.org/abs/2112.10752',
+          'https://github.com/CompVis/latent-diffusion'
+        )}
+        <NumberInputSetting
+          title="Steps"
+          value={`${setting.ldmSteps}`}
+          onValue={value => {
+            const val = value.length === 0 ? 0 : parseInt(value, 10)
+            setSettingState(old => {
+              return { ...old, ldmSteps: val }
+            })
+          }}
+        />
+      </div>
+    )
+  }
+
  const renderOptionDesc = (): ReactNode => {
    switch (setting.model) {
      case AIModel.LAMA:
@@ -58,11 +78,7 @@ function ModelSettingBlock() {
          'https://github.com/saic-mdal/lama'
        )
      case AIModel.LDM:
-        return renderModelDesc(
-          'High-Resolution Image Synthesis with Latent Diffusion Models',
-          'https://arxiv.org/abs/2112.10752',
-          'https://github.com/CompVis/latent-diffusion'
-        )
+        return renderLDMModelDesc()
      default:
        return <></>
    }

View File

@@ -0,0 +1,40 @@
import React from 'react'
import NumberInput from '../shared/NumberInput'
import SettingBlock from './SettingBlock'

interface NumberInputSettingProps {
  title: string
  value: string
  suffix?: string
  onValue: (val: string) => void
}

function NumberInputSetting(props: NumberInputSettingProps) {
  const { title, value, suffix, onValue } = props

  return (
    <SettingBlock
      className="sub-setting-block"
      title={title}
      input={
        <div
          style={{
            display: 'flex',
            justifyContent: 'center',
            alignItems: 'center',
            gap: '8px',
          }}
        >
          <NumberInput
            style={{ width: '80px' }}
            value={`${value}`}
            onValue={onValue}
          />
          {suffix && <span>{suffix}</span>}
        </div>
      }
    />
  )
}

export default NumberInputSetting

View File

@@ -1,14 +1,26 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
+import { settingState } from '../../store/Atoms'
import { Switch, SwitchThumb } from '../shared/Switch'
import SettingBlock from './SettingBlock'

function SavePathSettingBlock() {
+  const [setting, setSettingState] = useRecoilState(settingState)
+
+  const onCheckChange = (checked: boolean) => {
+    setSettingState(old => {
+      return { ...old, saveImageBesideOrigin: checked }
+    })
+  }
+
  return (
    <SettingBlock
      title="Download image beside origin image"
      input={
-        <Switch defaultChecked>
+        <Switch
+          checked={setting.saveImageBesideOrigin}
+          onCheckedChange={onCheckChange}
+        >
          <SwitchThumb />
        </Switch>
      }

View File

@@ -7,7 +7,7 @@
  margin-top: 12px;
  border: 1px solid var(--border-color);
  border-radius: 0.3rem;
-  padding: 2rem;
+  padding: 1rem;

  .sub-setting-block {
    margin-top: 8px;

View File

@@ -8,7 +8,6 @@
  background-color: var(--modal-bg);
  color: var(--modal-text-color);
  box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
-  min-height: 600px;
  width: 700px;

  @include mobile {

View File

@@ -1,11 +1,11 @@
import React from 'react'
import { useRecoilState } from 'recoil'
+import { switchModel } from '../../adapters/inpainting'
import { settingState } from '../../store/Atoms'
import Modal from '../shared/Modal'
import HDSettingBlock from './HDSettingBlock'
import ModelSettingBlock from './ModelSettingBlock'
-import SavePathSettingBlock from './SavePathSettingBlock'

export default function SettingModal() {
  const [setting, setSettingState] = useRecoilState(settingState)
@@ -14,6 +14,8 @@ export default function SettingModal() {
    setSettingState(old => {
      return { ...old, show: false }
    })
+
+    switchModel(setting.model)
  }

  return (
@@ -23,7 +25,9 @@ export default function SettingModal() {
      className="modal-setting"
      show={setting.show}
    >
-      <SavePathSettingBlock />
+      {/* It's not possible because this poses a security risk */}
+      {/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
+      {/* <SavePathSettingBlock /> */}
      <ModelSettingBlock />
      <HDSettingBlock />
    </Modal>

View File

@@ -1,9 +1,7 @@
import React from 'react'
-import { useRecoilValue } from 'recoil'
import Editor from './Editor/Editor'
-import { shortcutsState } from '../store/Atoms'
import ShortcutsModal from './Shortcuts/ShortcutsModal'
-import SettingModal from './Setting/SettingModal'
+import SettingModal from './Settings/SettingsModal'

interface WorkspaceProps {
  file: File

View File

@@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
  const ref = useRef(null)

  useClickAway(ref, () => {
+    if (show) {
      onClose?.()
+    }
  })

  useKeyPressEvent('Escape', e => {
+    if (show) {
      onClose?.()
+    }
  })

  return (

View File

@@ -1,6 +1,6 @@
import { atom } from 'recoil'
-import { HDStrategy } from '../components/Setting/HDSettingBlock'
-import { AIModel } from '../components/Setting/ModelSettingBlock'
+import { HDStrategy } from '../components/Settings/HDSettingBlock'
+import { AIModel } from '../components/Settings/ModelSettingBlock'

export const fileState = atom<File | undefined>({
  key: 'fileState',
@@ -16,10 +16,15 @@ export interface Setting {
  show: boolean
  saveImageBesideOrigin: boolean
  model: AIModel

+  // For LaMa
  hdStrategy: HDStrategy
  hdStrategyResizeLimit: number
  hdStrategyCropTrigerSize: number
  hdStrategyCropMargin: number
+
+  // For LDM
+  ldmSteps: number
}

export const settingState = atom<Setting>({
@@ -28,9 +33,10 @@ export const settingState = atom<Setting>({
    show: false,
    saveImageBesideOrigin: false,
    model: AIModel.LAMA,
-    hdStrategy: HDStrategy.ORIGINAL,
+    hdStrategy: HDStrategy.RESIZE,
    hdStrategyResizeLimit: 2048,
    hdStrategyCropTrigerSize: 2048,
    hdStrategyCropMargin: 128,
+    ldmSteps: 50,
  },
})

View File

@@ -11,7 +11,7 @@
@use '../components/Header/Header';
@use '../components/Header/ThemeChanger';
@use '../components/Shortcuts/Shortcuts';
-@use '../components/Setting/Setting.scss';
+@use '../components/Settings/Settings.scss';

// Shared
@use '../components/FileSelect/FileSelect';

View File

@@ -31,7 +31,11 @@ def ceil_modulo(x, mod):
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
-    data = cv2.imencode(f".{ext}", image_numpy)[1]
+    data = cv2.imencode(f".{ext}", image_numpy,
+                        [
+                            int(cv2.IMWRITE_JPEG_QUALITY), 100,
+                            int(cv2.IMWRITE_PNG_COMPRESSION), 0
+                        ])[1]
    image_bytes = data.tobytes()
    return image_bytes
@@ -74,13 +78,24 @@ def resize_max_size(
    return np_img

-def pad_img_to_modulo(img, mod):
-    channels, height, width = img.shape
+def pad_img_to_modulo(img: np.ndarray, mod: int):
+    """
+    Args:
+        img: [H, W, C]
+        mod:
+
+    Returns:
+    """
+    if len(img.shape) == 2:
+        img = img[:, :, np.newaxis]
+    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return np.pad(
        img,
-        ((0, 0), (0, out_height - height), (0, out_width - width)),
+        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )
@@ -88,15 +103,13 @@ def pad_img_to_modulo(img, mod):
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
-        mask: (1, h, w) 0~1
+        mask: (h, w, 1) 0~255

    Returns:
    """
-    height, width = mask.shape[1:]
-    _, thresh = cv2.threshold(
-        (mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
-    )
+    height, width = mask.shape[:2]
+    _, thresh = cv2.threshold(mask, 127, 255, 0)

    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
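
For reference, a minimal sketch of what the reworked pad_img_to_modulo now does with channel-last input (assuming numpy is imported as np and the helper above is in scope; the arrays are only illustrative):

    img = np.zeros((30, 50, 3), dtype=np.float32)   # channel-last [H, W, C] input
    padded = pad_img_to_modulo(img, mod=8)
    print(padded.shape)                             # (32, 56, 3): H and W padded up to multiples of 8
    gray = np.zeros((30, 50), dtype=np.float32)     # 2-D masks get a channel axis added first
    print(pad_img_to_modulo(gray, mod=8).shape)     # (32, 56, 1)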

View File

@@ -1,121 +0,0 @@
import os
from typing import List

import cv2
import torch
import numpy as np

from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)


class LaMa:
    def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
        """
        Args:
            crop_trigger_size: h, w
            crop_margin:
            device:
        """
        self.crop_trigger_size = crop_trigger_size
        self.crop_margin = crop_margin
        self.device = device

        if os.environ.get("LAMA_MODEL"):
            model_path = os.environ.get("LAMA_MODEL")
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        print(f"Load LaMa model from: {model_path}")
        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        image: [C, H, W] RGB
        mask: [1, H, W]
        return: BGR IMAGE
        """
        area = image.shape[1] * image.shape[2]
        if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
            return self._run(image, mask)

        print("Trigger crop image")
        boxes = boxes_from_mask(mask)
        crop_result = []
        for box in boxes:
            crop_image, crop_box = self._run_box(image, mask, box)
            crop_result.append((crop_image, crop_box))

        image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
        for crop_image, crop_box in crop_result:
            x1, y1, x2, y2 = crop_box
            image[y1:y2, x1:x2, :] = crop_image
        return image

    def _run_box(self, image, mask, box):
        """
        Args:
            image: [C, H, W] RGB
            mask: [1, H, W]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[1:]

        w = box_w + self.crop_margin * 2
        h = box_h + self.crop_margin * 2

        l = max(cx - w // 2, 0)
        t = max(cy - h // 2, 0)
        r = min(cx + w // 2, img_w)
        b = min(cy + h // 2, img_h)

        crop_img = image[:, t:b, l:r]
        crop_mask = mask[:, t:b, l:r]

        print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return self._run(crop_img, crop_mask), [l, t, r, b]

    def _run(self, image, mask):
        """
        image: [C, H, W] RGB
        mask: [1, H, W]
        return: BGR IMAGE
        """
        device = self.device
        origin_height, origin_width = image.shape[1:]
        image = pad_img_to_modulo(image, mod=8)
        mask = pad_img_to_modulo(mask, mod=8)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(device)

        inpainted_image = self.model(image, mask)

        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = cur_res[0:origin_height, 0:origin_width, :]
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        return cur_res

lama_cleaner/model/base.py (new file, 122 lines)
View File

@@ -0,0 +1,122 @@
import abc

import cv2
import torch
from loguru import logger

from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
from lama_cleaner.schema import Config, HDStrategy


class InpaintModel:
    pad_mod = 8

    def __init__(self, device):
        """
        Args:
            device:
        """
        self.device = device
        self.init_model(device)

    @abc.abstractmethod
    def init_model(self, device):
        ...

    @abc.abstractmethod
    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask, config: Config):
        origin_height, origin_width = image.shape[:2]
        padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
        padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)

        result = self.forward(padd_image, padd_mask, config)
        result = result[0:origin_height, 0:origin_width, :]

        original_pixel_indices = mask != 255
        result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
        return result

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        image: [H, W, C] RGB, not normalized
        mask: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = None
        logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                logger.info(f"Run crop strategy")
                boxes = boxes_from_mask(mask)
                crop_result = []
                for box in boxes:
                    crop_image, crop_box = self._run_box(image, mask, box, config)
                    crop_result.append((crop_image, crop_box))

                inpaint_result = image[:, :, ::-1]
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image

        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
                downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
                logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
                inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)

                # only paste masked area result
                inpaint_result = cv2.resize(inpaint_result,
                                            (origin_size[1], origin_size[0]),
                                            interpolation=cv2.INTER_CUBIC)
                original_pixel_indices = mask != 255
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]

        if inpaint_result is None:
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    def _run_box(self, image, mask, box, config: Config):
        """
        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        l = max(cx - w // 2, 0)
        t = max(cy - h // 2, 0)
        r = min(cx + w // 2, img_w)
        b = min(cy + h // 2, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]

        logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]

View File

@@ -0,0 +1,64 @@
import os

import cv2
import numpy as np
import torch
from loguru import logger

from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)


class LaMa(InpaintModel):
    pad_mod = 8

    def __init__(self, device):
        """
        Args:
            device:
        """
        super().__init__(device)
        self.device = device

    def init_model(self, device):
        if os.environ.get("LAMA_MODEL"):
            model_path = os.environ.get("LAMA_MODEL")
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        logger.info(f"Load LaMa model from: {model_path}")
        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        inpainted_image = self.model(image, mask)

        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        return cur_res

View File

@@ -2,13 +2,16 @@ import os
import numpy as np
import torch
+from loguru import logger
+
+from lama_cleaner.model.base import InpaintModel
+from lama_cleaner.schema import Config

torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
-import cv2
-from lama_cleaner.helper import pad_img_to_modulo, download_model
-from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
+from lama_cleaner.helper import download_model, norm_img
+from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
    timestep_embedding

LDM_ENCODE_MODEL_URL = os.environ.get(
@@ -217,7 +220,7 @@ class DDIMSampler(object):
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
-        print(f"Running DDIM Sampling with {total_steps} timesteps")
+        logger.info(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
@@ -268,97 +271,39 @@ def load_jit_model(url, device):
    return model

-class LDM:
-    def __init__(self, device, steps=50):
+class LDM(InpaintModel):
+    pad_mod = 32
+
+    def __init__(self, device):
+        super().__init__(device)
        self.device = device
+
+    def init_model(self, device):
        self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
        self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
        self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
        model = LatentDiffusion(self.diffusion_model, device)
        self.sampler = DDIMSampler(model)
-        self.steps = steps

-    def _norm(self, tensor):
-        return tensor * 2.0 - 1.0
-
-    @torch.no_grad()
-    def __call__(self, image, mask):
+    def forward(self, image, mask, config: Config):
        """
-        image: [C, H, W] RGB
-        mask: [1, H, W]
+        image: [H, W, C] RGB
+        mask: [H, W, 1]
        return: BGR IMAGE
        """
        # image [1,3,512,512] float32
        # mask: [1,1,512,512] float32
        # masked_image: [1,3,512,512] float32
-        origin_height, origin_width = image.shape[1:]
-        image = pad_img_to_modulo(image, mod=32)
-        mask = pad_img_to_modulo(mask, mod=32)
-        padded_height, padded_width = image.shape[1:]
+        steps = config.ldm_steps
+        image = norm_img(image)
+        mask = norm_img(mask)

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

-        # crop 512 x 512
-        if padded_width <= 512 or padded_height <= 512:
-            np_img = self._forward(image, mask, self.device)
-        else:
-            print("Try to zoom in")
-            # zoom in
-            # x,y,w,h
-            # box = self.box_from_bitmap(mask)
-            box = self.find_main_content(mask)
-            if box is None:
-                print("No bbox found")
-                np_img = self._forward(image, mask, self.device)
-            else:
-                print(f"box: {box}")
-                box_x, box_y, box_w, box_h = box
-                cx = box_x + box_w // 2
-                cy = box_y + box_h // 2
-                # w = max(512, box_w)
-                # h = max(512, box_h)
-                w = box_w + 512
-                h = box_h + 512
-                left = max(cx - w // 2, 0)
-                top = max(cy - h // 2, 0)
-                right = min(cx + w // 2, origin_width)
-                bottom = min(cy + h // 2, origin_height)
-                x = left
-                y = top
-                w = right - left
-                h = bottom - top
-                crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
-                crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
-                print(f"Apply zoom in size width x height: {crop_img.shape}")
-                crop_img_height, crop_img_width = crop_img.shape[1:]
-                crop_img = pad_img_to_modulo(crop_img, mod=32)
-                crop_mask = pad_img_to_modulo(crop_mask, mod=32)
-                # RGB
-                np_img = self._forward(crop_img, crop_mask, self.device)
-                image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
-                image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
-                np_img = image
-                # BGR to RGB
-                # np_img = image[:, :, ::-1]
-
-        np_img = np_img[0:origin_height, 0:origin_width, :]
-        np_img = np_img[:, :, ::-1]
-        return np_img
-
-    def _forward(self, image, mask, device):
-        image = torch.from_numpy(image).unsqueeze(0).to(device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        masked_image = (1 - mask) * image
        image = self._norm(image)
@@ -371,47 +316,20 @@ class LDM:
        c = torch.cat((c, cc), dim=1)  # 1,4,128,128
        shape = (c.shape[1] - 1,) + c.shape[2:]
-        samples_ddim = self.sampler.sample(steps=self.steps,
+        samples_ddim = self.sampler.sample(steps=steps,
                                           conditioning=c,
                                           batch_size=c.shape[0],
                                           shape=shape)
        x_samples_ddim = self.cond_stage_model_decode(samples_ddim)  # samples_ddim: 1, 3, 128, 128 float32

-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-        mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
-        predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
+        inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

-        inpainted = (1 - mask) * image + mask * predicted_image
-        inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
-        np_img = inpainted.astype(np.uint8)
-        return np_img
+        # inpainted = (1 - mask) * image + mask * predicted_image
+        inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
+        inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
+        return inpainted_image

-    def find_main_content(self, bitmap: np.ndarray):
-        th2 = bitmap[0].astype(np.uint8)
-        row_sum = th2.sum(1)
-        col_sum = th2.sum(0)
-        xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
-        xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
-        ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
-        ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
-        left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
-        return left, top, right - left, bottom - top
-
-    def box_from_bitmap(self, bitmap):
-        """
-        bitmap: single map with shape (NUM_CLASSES, H, W),
-            whose values are binarized as {0, 1}
-        """
-        contours, _ = cv2.findContours(
-            (bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
-        )
-        contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
-        num_contours = len(contours)
-        print(f"contours size: {num_contours}")
-        if num_contours != 1:
-            return None
-        # x,y,w,h
-        return cv2.boundingRect(contours[0])
+    def _norm(self, tensor):
+        return tensor * 2.0 - 1.0

View File

@@ -0,0 +1,34 @@
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.schema import Config


class ModelManager:
    LAMA = 'lama'
    LDM = 'ldm'

    def __init__(self, name: str, device):
        self.name = name
        self.device = device
        self.model = self.init_model(name, device)

    def init_model(self, name: str, device):
        if name == self.LAMA:
            model = LaMa(device)
        elif name == self.LDM:
            model = LDM(device)
        else:
            raise NotImplementedError(f"Not supported model: {name}")
        return model

    def __call__(self, image, mask, config: Config):
        return self.model(image, mask, config)

    def switch(self, new_name: str):
        if new_name == self.name:
            return
        try:
            self.model = self.init_model(new_name, self.device)
            self.name = new_name
        except NotImplementedError as e:
            raise e

lama_cleaner/schema.py (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
from enum import Enum

from pydantic import BaseModel


class HDStrategy(str, Enum):
    ORIGINAL = 'Original'
    RESIZE = 'Resize'
    CROP = 'Crop'


class Config(BaseModel):
    ldm_steps: int
    hd_strategy: str
    hd_strategy_crop_margin: int
    hd_strategy_crop_trigger_size: int
    hd_strategy_resize_limit: int
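
For reference, a minimal sketch of constructing this Config directly (the values mirror the frontend defaults in Atoms.ts and are only illustrative; the schema itself declares no defaults):

    config = Config(
        ldm_steps=50,
        hd_strategy=HDStrategy.RESIZE,
        hd_strategy_crop_margin=128,
        hd_strategy_crop_trigger_size=2048,
        hd_strategy_resize_limit=2048,
    )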

View File

(Binary image files, not shown: seven new images added — one 129 KiB, six 193 KiB — one 11 KiB binary removed, and lama_cleaner/tests/mask.png added, 7.7 KiB.)

View File

@@ -1,15 +0,0 @@
import cv2
import numpy as np

from lama_cleaner.helper import boxes_from_mask


def test_boxes_from_mask():
    mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
    mask = mask[:, :, np.newaxis]
    mask = (mask / 255).transpose(2, 0, 1)
    boxes = boxes_from_mask(mask)
    print(boxes)


test_boxes_from_mask()

View File

@@ -0,0 +1,55 @@
import os
from pathlib import Path

import cv2
import numpy as np
import pytest

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy

current_dir = Path(__file__).parent.absolute().resolve()


def get_data():
    img = cv2.imread(str(current_dir / 'image.png'))
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
    return img, mask


def get_config(strategy):
    return Config(
        ldm_steps=1,
        hd_strategy=strategy,
        hd_strategy_crop_margin=32,
        hd_strategy_crop_trigger_size=200,
        hd_strategy_resize_limit=200,
    )


def assert_equal(model, config, gt_name):
    img, mask = get_data()
    res = model(img, mask, config)
    # cv2.imwrite(gt_name, res,
    #             [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])

    """
    Note that JPEG is lossy compression, so even if it is the highest quality 100,
    when the saved image is reloaded, a difference occurs with the original pixel value.
    If you want to save the original image as it is, save it as PNG or BMP.
    """
    gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
    assert np.array_equal(res, gt)


@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_lama(strategy):
    model = ModelManager(name='lama', device='cpu')
    assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')


@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_ldm(strategy):
    model = ModelManager(name='ldm', device='cpu')
    assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')

main.py (94 changed lines)
View File

@@ -2,19 +2,21 @@
import argparse
import io
+import logging
import multiprocessing
import os
import time
import imghdr
+from pathlib import Path
from typing import Union

import cv2
import torch
import numpy as np
-from lama_cleaner.lama import LaMa
-from lama_cleaner.ldm import LDM
-from flaskwebgui import FlaskUI
+from loguru import logger
+
+from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.schema import Config

try:
    torch._C._jit_override_can_fuse_on_cpu(False)
@@ -29,7 +31,6 @@ from flask_cors import CORS
from lama_cleaner.helper import (
    load_img,
-    norm_img,
    numpy_to_bytes,
    resize_max_size,
)
@@ -46,11 +47,19 @@ if os.environ.get("CACHE_DIR"):
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")

+
+class InterceptHandler(logging.Handler):
+    def emit(self, record):
+        logger_opt = logger.opt(depth=6, exception=record.exc_info)
+        logger_opt.log(record.levelno, record.getMessage())
+
+
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
-CORS(app)
+app.logger.addHandler(InterceptHandler())
+CORS(app, expose_headers=["Content-Disposition"])

-model = None
+model: ModelManager = None
device = None
input_image_path: str = None
@@ -72,24 +81,31 @@ def process():
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

-    size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
+    form = request.form
+    size_limit: Union[int, str] = form.get("sizeLimit", "1080")
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

-    print(f"Origin image shape: {original_shape}")
+    config = Config(
+        ldm_steps=form['ldmSteps'],
+        hd_strategy=form['hdStrategy'],
+        hd_strategy_crop_margin=form['hdStrategyCropMargin'],
+        hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
+        hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
+    )
+
+    logger.info(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
-    print(f"Resized image shape: {image.shape}")
-    image = norm_img(image)
+    logger.info(f"Resized image shape: {image.shape}")

    mask, _ = load_img(input["mask"].read(), gray=True)
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
-    mask = norm_img(mask)

    start = time.time()
-    res_np_img = model(image, mask)
-    print(f"process time: {(time.time() - start) * 1000}ms")
+    res_np_img = model(image, mask, config)
+    logger.info(f"process time: {(time.time() - start) * 1000}ms")

    torch.cuda.empty_cache()
@@ -109,6 +125,19 @@ def process():
    )

+
+@app.route("/switch_model", methods=["POST"])
+def switch_model():
+    new_name = request.form.get("name")
+    if new_name == model.name:
+        return "Same model", 200
+
+    try:
+        model.switch(new_name)
+    except NotImplementedError:
+        return f"{new_name} not implemented", 403
+    return f"ok, switch to {new_name}", 200
+
+
@app.route("/")
def index():
    return send_file(os.path.join(BUILD_DIR, "index.html"))
@@ -120,7 +149,9 @@ def set_input_photo():
        with open(input_image_path, "rb") as f:
            image_in_bytes = f.read()
        return send_file(
-            io.BytesIO(image_in_bytes),
+            input_image_path,
+            as_attachment=True,
+            download_name=Path(input_image_path).name,
            mimetype=f"image/{get_image_ext(image_in_bytes)}",
        )
    else:
@@ -135,29 +166,6 @@ def get_args_parser():
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
-    parser.add_argument(
-        "--crop-trigger-size",
-        default=[2042, 2042],
-        nargs=2,
-        type=int,
-        help="If image size large then crop-trigger-size, "
-        "crop each area from original image to do inference."
-        "Mainly for performance and memory reasons"
-        "Only for lama",
-    )
-    parser.add_argument(
-        "--crop-margin",
-        type=int,
-        default=256,
-        help="Margin around bounding box of painted stroke when crop mode triggered",
-    )
-    parser.add_argument(
-        "--ldm-steps",
-        default=50,
-        type=int,
-        help="Steps for DDIM sampling process."
-        "The larger the value, the better the result, but it will be more time-consuming",
-    )
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
    parser.add_argument(
@@ -188,19 +196,11 @@ def main():
    device = torch.device(args.device)
    input_image_path = args.input

-    if args.model == "lama":
-        model = LaMa(
-            crop_trigger_size=args.crop_trigger_size,
-            crop_margin=args.crop_margin,
-            device=device,
-        )
-    elif args.model == "ldm":
-        model = LDM(device, steps=args.ldm_steps)
-    else:
-        raise NotImplementedError(f"Not supported model: {args.model}")
+    model = ModelManager(name=args.model, device=device)

    if args.gui:
        app_width, app_height = args.gui_size
+        from flaskwebgui import FlaskUI
        ui = FlaskUI(app, width=app_width, height=app_height)
        ui.run()
    else:
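
For reference, a quick way to exercise the new /switch_model route against a locally running server (assuming the requests package is installed and the default host/port are used):

    import requests

    resp = requests.post("http://127.0.0.1:8080/switch_model", data={"name": "ldm"})
    print(resp.status_code, resp.text)  # 200 "ok, switch to ldm", or 403 if the name is not implemented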

View File

@@ -1,6 +1,9 @@
torch>=1.8.2
opencv-python
flask_cors
-flask
+flask==2.1.1
flaskwebgui
tqdm
+pydantic
+loguru
+pytest