big update
lama_cleaner/app/src/adapters/inpainting.ts
@@ -1,10 +1,12 @@
+import { Setting } from '../store/Atoms'
 import { dataURItoBlob } from '../utils'

-export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
+export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`

 export default async function inpaint(
   imageFile: File,
   maskBase64: string,
+  settings: Setting,
   sizeLimit?: string
 ) {
   // 1080, 2000, Original
@@ -13,13 +15,22 @@ export default async function inpaint(
   const mask = dataURItoBlob(maskBase64)
   fd.append('mask', mask)

+  fd.append('ldmSteps', settings.ldmSteps.toString())
+  fd.append('hdStrategy', settings.hdStrategy)
+  fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
+  fd.append(
+    'hdStrategyCropTrigerSize',
+    settings.hdStrategyCropTrigerSize.toString()
+  )
+  fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())
+
   if (sizeLimit === undefined) {
     fd.append('sizeLimit', '1080')
   } else {
     fd.append('sizeLimit', sizeLimit)
   }

-  const res = await fetch(API_ENDPOINT, {
+  const res = await fetch(`${API_ENDPOINT}/inpaint`, {
     method: 'POST',
     body: fd,
   }).then(async r => {
@@ -28,3 +39,12 @@ export default async function inpaint(

   return URL.createObjectURL(res)
 }
+
+export function switchModel(name: string) {
+  const fd = new FormData()
+  fd.append('name', name)
+  return fetch(`${API_ENDPOINT}/switch_model`, {
+    method: 'POST',
+    body: fd,
+  })
+}
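For reference, the request the updated frontend builds can be reproduced outside the browser. A minimal Python sketch (the 'image' field name and the default dev server address are assumptions; the other field names are copied from the fd.append calls above):

import requests

def inpaint(image_path, mask_path, server='http://127.0.0.1:8080'):
    # Field names mirror the fd.append calls in inpainting.ts; 'image' is
    # assumed from elided context, the rest are shown in the diff above.
    files = {'image': open(image_path, 'rb'), 'mask': open(mask_path, 'rb')}
    form = {
        'ldmSteps': '50',
        'hdStrategy': 'Resize',
        'hdStrategyCropMargin': '128',
        'hdStrategyCropTrigerSize': '2048',  # spelling matches the form key above
        'hdStrategyResizeLimit': '2048',
        'sizeLimit': '1080',
    }
    res = requests.post(f'{server}/inpaint', files=files, data=form)
    res.raise_for_status()
    return res.content  # encoded result image bytes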
lama_cleaner/app/src/components/Editor/Editor.tsx
@@ -15,12 +15,14 @@ import {
   TransformComponent,
   TransformWrapper,
 } from 'react-zoom-pan-pinch'
+import { useRecoilValue } from 'recoil'
 import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
 import inpaint from '../../adapters/inpainting'
 import Button from '../shared/Button'
 import Slider from './Slider'
 import SizeSelector from './SizeSelector'
 import { downloadImage, loadImage, useImage } from '../../utils'
+import { settingState } from '../../store/Atoms'

 const TOOLBAR_SIZE = 200
 const BRUSH_COLOR = '#ffcc00bb'
@@ -57,6 +59,7 @@ function drawLines(

 export default function Editor(props: EditorProps) {
   const { file } = props
+  const settings = useRecoilValue(settingState)
   const [brushSize, setBrushSize] = useState(40)
   const [original, isOriginalLoaded] = useImage(file)
   const [renders, setRenders] = useState<HTMLImageElement[]>([])
@@ -125,6 +128,7 @@ export default function Editor(props: EditorProps) {
       const res = await inpaint(
         file,
         maskCanvas.toDataURL(),
+        settings,
         sizeLimit.toString()
       )
       if (!res) {
@@ -157,6 +161,7 @@ export default function Editor(props: EditorProps) {
     renders,
     sizeLimit,
     historyLineCount,
+    settings,
   ])

   const hadDrawSomething = () => {
lama_cleaner/app/src/components/Header/Header.tsx
@@ -6,7 +6,7 @@ import Button from '../shared/Button'
 import Shortcuts from '../Shortcuts/Shortcuts'
 import useResolution from '../../hooks/useResolution'
 import { ThemeChanger } from './ThemeChanger'
-import SettingIcon from '../Setting/SettingIcon'
+import SettingIcon from '../Settings/SettingIcon'

 const Header = () => {
   const [file, setFile] = useRecoilState(fileState)
lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx
@@ -1,50 +1,16 @@
 import React, { ReactNode } from 'react'
 import { useRecoilState } from 'recoil'
 import { settingState } from '../../store/Atoms'
-import NumberInput from '../shared/NumberInput'
 import Selector from '../shared/Selector'
+import NumberInputSetting from './NumberInputSetting'
 import SettingBlock from './SettingBlock'

 export enum HDStrategy {
   ORIGINAL = 'Original',
-  REISIZE = 'Resize',
+  RESIZE = 'Resize',
   CROP = 'Crop',
 }

-interface PixelSizeInputProps {
-  title: string
-  value: string
-  onValue: (val: string) => void
-}
-
-function PixelSizeInputSetting(props: PixelSizeInputProps) {
-  const { title, value, onValue } = props
-
-  return (
-    <SettingBlock
-      className="sub-setting-block"
-      title={title}
-      input={
-        <div
-          style={{
-            display: 'flex',
-            justifyContent: 'center',
-            alignItems: 'center',
-            gap: '8px',
-          }}
-        >
-          <NumberInput
-            style={{ width: '80px' }}
-            value={`${value}`}
-            onValue={onValue}
-          />
-          <span>pixel</span>
-        </div>
-      }
-    />
-  )
-}
-
 function HDSettingBlock() {
   const [setting, setSettingState] = useRecoilState(settingState)
@@ -84,7 +50,7 @@ function HDSettingBlock() {
       tabIndex={0}
       role="button"
       className="inline-tip"
-      onClick={() => onStrategyChange(HDStrategy.REISIZE)}
+      onClick={() => onStrategyChange(HDStrategy.RESIZE)}
     >
       Resize Strategy
     </div>{' '}
@@ -100,9 +66,10 @@ function HDSettingBlock() {
         Resize the longer side of the image to a specific size(keep ratio),
         then do inpainting on the resized image.
       </div>
-      <PixelSizeInputSetting
+      <NumberInputSetting
        title="Size limit"
        value={`${setting.hdStrategyResizeLimit}`}
+       suffix="pixel"
        onValue={onResizeLimitChange}
      />
    </div>
@@ -117,14 +84,16 @@ function HDSettingBlock() {
        the result back. Mainly for performance and memory reasons on high
        resolution image.
      </div>
-      <PixelSizeInputSetting
+      <NumberInputSetting
        title="Trigger size"
        value={`${setting.hdStrategyCropTrigerSize}`}
+       suffix="pixel"
        onValue={onCropTriggerSizeChange}
      />
-      <PixelSizeInputSetting
+      <NumberInputSetting
        title="Crop margin"
        value={`${setting.hdStrategyCropMargin}`}
+       suffix="pixel"
        onValue={onCropMarginChange}
      />
    </div>
@@ -137,7 +106,7 @@ function HDSettingBlock() {
       return renderOriginalOptionDesc()
     case HDStrategy.CROP:
       return renderCropOptionDesc()
-    case HDStrategy.REISIZE:
+    case HDStrategy.RESIZE:
       return renderResizeOptionDesc()
     default:
       return renderOriginalOptionDesc()
lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
@@ -2,11 +2,12 @@ import React, { ReactNode } from 'react'
 import { useRecoilState } from 'recoil'
 import { settingState } from '../../store/Atoms'
 import Selector from '../shared/Selector'
+import NumberInputSetting from './NumberInputSetting'
 import SettingBlock from './SettingBlock'

 export enum AIModel {
-  LAMA = 'LaMa',
-  LDM = 'LDM',
+  LAMA = 'lama',
+  LDM = 'ldm',
 }

 function ModelSettingBlock() {
@@ -24,7 +25,7 @@ function ModelSettingBlock() {
     githubUrl: string
   ) => {
     return (
-      <div style={{ display: 'flex', flexDirection: 'column' }}>
+      <div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
         <a
           className="model-desc-link"
           href={paperUrl}
@@ -34,14 +35,11 @@ function ModelSettingBlock() {
           {name}
         </a>

-        <br />
-
         <a
           className="model-desc-link"
           href={githubUrl}
           target="_blank"
           rel="noreferrer noopener"
-          style={{ marginTop: '8px' }}
         >
           {githubUrl}
         </a>
@@ -49,6 +47,28 @@ function ModelSettingBlock() {
     )
   }

+  const renderLDMModelDesc = () => {
+    return (
+      <div>
+        {renderModelDesc(
+          'High-Resolution Image Synthesis with Latent Diffusion Models',
+          'https://arxiv.org/abs/2112.10752',
+          'https://github.com/CompVis/latent-diffusion'
+        )}
+        <NumberInputSetting
+          title="Steps"
+          value={`${setting.ldmSteps}`}
+          onValue={value => {
+            const val = value.length === 0 ? 0 : parseInt(value, 10)
+            setSettingState(old => {
+              return { ...old, ldmSteps: val }
+            })
+          }}
+        />
+      </div>
+    )
+  }
+
   const renderOptionDesc = (): ReactNode => {
     switch (setting.model) {
       case AIModel.LAMA:
@@ -58,11 +78,7 @@ function ModelSettingBlock() {
           'https://github.com/saic-mdal/lama'
         )
       case AIModel.LDM:
-        return renderModelDesc(
-          'High-Resolution Image Synthesis with Latent Diffusion Models',
-          'https://arxiv.org/abs/2112.10752',
-          'https://github.com/CompVis/latent-diffusion'
-        )
+        return renderLDMModelDesc()
       default:
         return <></>
     }
lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx (new file, 40 lines)

import React from 'react'
import NumberInput from '../shared/NumberInput'
import SettingBlock from './SettingBlock'

interface NumberInputSettingProps {
  title: string
  value: string
  suffix?: string
  onValue: (val: string) => void
}

function NumberInputSetting(props: NumberInputSettingProps) {
  const { title, value, suffix, onValue } = props

  return (
    <SettingBlock
      className="sub-setting-block"
      title={title}
      input={
        <div
          style={{
            display: 'flex',
            justifyContent: 'center',
            alignItems: 'center',
            gap: '8px',
          }}
        >
          <NumberInput
            style={{ width: '80px' }}
            value={`${value}`}
            onValue={onValue}
          />
          {suffix && <span>{suffix}</span>}
        </div>
      }
    />
  )
}

export default NumberInputSetting
lama_cleaner/app/src/components/Settings/SavePathSettingBlock.tsx
@@ -1,14 +1,26 @@
 import React, { ReactNode } from 'react'
 import { useRecoilState } from 'recoil'
+import { settingState } from '../../store/Atoms'
 import { Switch, SwitchThumb } from '../shared/Switch'
 import SettingBlock from './SettingBlock'

 function SavePathSettingBlock() {
+  const [setting, setSettingState] = useRecoilState(settingState)
+
+  const onCheckChange = (checked: boolean) => {
+    setSettingState(old => {
+      return { ...old, saveImageBesideOrigin: checked }
+    })
+  }
+
   return (
     <SettingBlock
       title="Download image beside origin image"
       input={
-        <Switch defaultChecked>
+        <Switch
+          checked={setting.saveImageBesideOrigin}
+          onCheckedChange={onCheckChange}
+        >
           <SwitchThumb />
         </Switch>
       }
lama_cleaner/app/src/components/Settings/Settings.scss
@@ -7,7 +7,7 @@
   margin-top: 12px;
   border: 1px solid var(--border-color);
   border-radius: 0.3rem;
-  padding: 2rem;
+  padding: 1rem;

   .sub-setting-block {
     margin-top: 8px;
@@ -8,7 +8,6 @@
   background-color: var(--modal-bg);
   color: var(--modal-text-color);
   box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
-  min-height: 600px;
   width: 700px;

   @include mobile {
lama_cleaner/app/src/components/Settings/SettingsModal.tsx (renamed from Setting/SettingModal.tsx)
@@ -1,11 +1,11 @@
 import React from 'react'
-
 import { useRecoilState } from 'recoil'
+import { switchModel } from '../../adapters/inpainting'
 import { settingState } from '../../store/Atoms'
 import Modal from '../shared/Modal'
 import HDSettingBlock from './HDSettingBlock'
 import ModelSettingBlock from './ModelSettingBlock'
 import SavePathSettingBlock from './SavePathSettingBlock'

 export default function SettingModal() {
   const [setting, setSettingState] = useRecoilState(settingState)
@@ -14,6 +14,8 @@ export default function SettingModal() {
     setSettingState(old => {
       return { ...old, show: false }
     })
+
+    switchModel(setting.model)
   }

   return (
@@ -23,7 +25,9 @@ export default function SettingModal() {
       className="modal-setting"
       show={setting.show}
     >
-      <SavePathSettingBlock />
+      {/* It's not possible because this poses a security risk */}
+      {/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
+      {/* <SavePathSettingBlock /> */}
       <ModelSettingBlock />
       <HDSettingBlock />
     </Modal>
lama_cleaner/app/src/components/Workspace.tsx
@@ -1,9 +1,7 @@
 import React from 'react'
-import { useRecoilValue } from 'recoil'
 import Editor from './Editor/Editor'
-import { shortcutsState } from '../store/Atoms'
 import ShortcutsModal from './Shortcuts/ShortcutsModal'
-import SettingModal from './Setting/SettingModal'
+import SettingModal from './Settings/SettingsModal'

 interface WorkspaceProps {
   file: File
lama_cleaner/app/src/components/shared/Modal.tsx
@@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
   const ref = useRef(null)

   useClickAway(ref, () => {
-    onClose?.()
+    if (show) {
+      onClose?.()
+    }
   })

   useKeyPressEvent('Escape', e => {
-    onClose?.()
+    if (show) {
+      onClose?.()
+    }
   })

   return (
lama_cleaner/app/src/store/Atoms.ts
@@ -1,6 +1,6 @@
 import { atom } from 'recoil'
-import { HDStrategy } from '../components/Setting/HDSettingBlock'
-import { AIModel } from '../components/Setting/ModelSettingBlock'
+import { HDStrategy } from '../components/Settings/HDSettingBlock'
+import { AIModel } from '../components/Settings/ModelSettingBlock'

 export const fileState = atom<File | undefined>({
   key: 'fileState',
@@ -16,10 +16,15 @@ export interface Setting {
   show: boolean
   saveImageBesideOrigin: boolean
   model: AIModel
+
+  // For LaMa
   hdStrategy: HDStrategy
   hdStrategyResizeLimit: number
   hdStrategyCropTrigerSize: number
   hdStrategyCropMargin: number
+
+  // For LDM
+  ldmSteps: number
 }

 export const settingState = atom<Setting>({
@@ -28,9 +33,10 @@ export const settingState = atom<Setting>({
     show: false,
     saveImageBesideOrigin: false,
     model: AIModel.LAMA,
-    hdStrategy: HDStrategy.ORIGINAL,
+    hdStrategy: HDStrategy.RESIZE,
     hdStrategyResizeLimit: 2048,
     hdStrategyCropTrigerSize: 2048,
     hdStrategyCropMargin: 128,
+    ldmSteps: 50,
   },
 })
@@ -11,7 +11,7 @@
 @use '../components/Header/Header';
 @use '../components/Header/ThemeChanger';
 @use '../components/Shortcuts/Shortcuts';
-@use '../components/Setting/Setting.scss';
+@use '../components/Settings/Settings.scss';

 // Shared
 @use '../components/FileSelect/FileSelect';
lama_cleaner/helper.py
@@ -31,7 +31,11 @@ def ceil_modulo(x, mod):


 def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
-    data = cv2.imencode(f".{ext}", image_numpy)[1]
+    data = cv2.imencode(f".{ext}", image_numpy,
+                        [
+                            int(cv2.IMWRITE_JPEG_QUALITY), 100,
+                            int(cv2.IMWRITE_PNG_COMPRESSION), 0
+                        ])[1]
     image_bytes = data.tobytes()
     return image_bytes
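The extra argument to cv2.imencode is a flat list of flag/value pairs, and flags irrelevant to the chosen format are ignored, so one list can carry both the JPEG and PNG settings. A quick standalone check (PNG at compression 0 round-trips exactly; JPEG at quality 100 is still lossy):

import cv2
import numpy as np

img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
params = [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0]
png = cv2.imencode('.png', img, params)[1].tobytes()
decoded = cv2.imdecode(np.frombuffer(png, np.uint8), cv2.IMREAD_UNCHANGED)
assert np.array_equal(decoded, img)  # PNG with compression 0 is lossless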
@@ -74,13 +78,24 @@ def resize_max_size(
     return np_img


-def pad_img_to_modulo(img, mod):
-    channels, height, width = img.shape
+def pad_img_to_modulo(img: np.ndarray, mod: int):
+    """
+
+    Args:
+        img: [H, W, C]
+        mod:
+
+    Returns:
+
+    """
+    if len(img.shape) == 2:
+        img = img[:, :, np.newaxis]
+    height, width = img.shape[:2]
     out_height = ceil_modulo(height, mod)
     out_width = ceil_modulo(width, mod)
     return np.pad(
         img,
-        ((0, 0), (0, out_height - height), (0, out_width - width)),
+        ((0, out_height - height), (0, out_width - width), (0, 0)),
         mode="symmetric",
     )

@@ -88,15 +103,13 @@ def pad_img_to_modulo(img, mod):
 def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
     """
     Args:
-        mask: (1, h, w) 0~1
+        mask: (h, w, 1) 0~255

     Returns:

     """
-    height, width = mask.shape[1:]
-    _, thresh = cv2.threshold(
-        (mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
-    )
+    height, width = mask.shape[:2]
+    _, thresh = cv2.threshold(mask, 127, 255, 0)
     contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

     boxes = []
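Both helpers switch from CHW to HWC layout here. A standalone re-creation of the new pad_img_to_modulo for illustration; ceil_modulo is not shown in this diff, so its round-up behavior is an assumption:

import numpy as np

def ceil_modulo(x, mod):
    # assumed behavior: round x up to the next multiple of mod
    return x if x % mod == 0 else (x // mod + 1) * mod

def pad_img_to_modulo(img, mod):
    # mirror of the new helper above: HWC in, 2-D masks promoted to [H, W, 1]
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    return np.pad(
        img,
        ((0, ceil_modulo(height, mod) - height),
         (0, ceil_modulo(width, mod) - width),
         (0, 0)),
        mode='symmetric',
    )

print(pad_img_to_modulo(np.zeros((30, 50, 3), np.uint8), 8).shape)   # (32, 56, 3)
print(pad_img_to_modulo(np.zeros((30, 50), np.uint8), 32).shape)     # (32, 64, 1)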
lama_cleaner/lama.py (deleted, 121 lines)

import os
from typing import List

import cv2
import torch
import numpy as np

from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)


class LaMa:
    def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
        """

        Args:
            crop_trigger_size: h, w
            crop_margin:
            device:
        """
        self.crop_trigger_size = crop_trigger_size
        self.crop_margin = crop_margin
        self.device = device

        if os.environ.get("LAMA_MODEL"):
            model_path = os.environ.get("LAMA_MODEL")
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        print(f"Load LaMa model from: {model_path}")
        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        image: [C, H, W] RGB
        mask: [1, H, W]
        return: BGR IMAGE
        """
        area = image.shape[1] * image.shape[2]
        if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
            return self._run(image, mask)

        print("Trigger crop image")
        boxes = boxes_from_mask(mask)
        crop_result = []
        for box in boxes:
            crop_image, crop_box = self._run_box(image, mask, box)
            crop_result.append((crop_image, crop_box))

        image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
        for crop_image, crop_box in crop_result:
            x1, y1, x2, y2 = crop_box
            image[y1:y2, x1:x2, :] = crop_image
        return image

    def _run_box(self, image, mask, box):
        """

        Args:
            image: [C, H, W] RGB
            mask: [1, H, W]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[1:]

        w = box_w + self.crop_margin * 2
        h = box_h + self.crop_margin * 2

        l = max(cx - w // 2, 0)
        t = max(cy - h // 2, 0)
        r = min(cx + w // 2, img_w)
        b = min(cy + h // 2, img_h)

        crop_img = image[:, t:b, l:r]
        crop_mask = mask[:, t:b, l:r]

        print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return self._run(crop_img, crop_mask), [l, t, r, b]

    def _run(self, image, mask):
        """
        image: [C, H, W] RGB
        mask: [1, H, W]
        return: BGR IMAGE
        """
        device = self.device
        origin_height, origin_width = image.shape[1:]
        image = pad_img_to_modulo(image, mod=8)
        mask = pad_img_to_modulo(mask, mod=8)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(device)

        inpainted_image = self.model(image, mask)

        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = cur_res[0:origin_height, 0:origin_width, :]
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        return cur_res
lama_cleaner/model/base.py (new file, 122 lines)

import abc

import cv2
import torch
from loguru import logger

from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
from lama_cleaner.schema import Config, HDStrategy


class InpaintModel:
    pad_mod = 8

    def __init__(self, device):
        """

        Args:
            device:
        """
        self.device = device
        self.init_model(device)

    @abc.abstractmethod
    def init_model(self, device):
        ...

    @abc.abstractmethod
    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask, config: Config):
        origin_height, origin_width = image.shape[:2]
        padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
        padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
        result = self.forward(padd_image, padd_mask, config)
        result = result[0:origin_height, 0:origin_width, :]

        original_pixel_indices = mask != 255
        result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
        return result

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        image: [H, W, C] RGB, not normalized
        mask: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = None
        logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                logger.info(f"Run crop strategy")
                boxes = boxes_from_mask(mask)
                crop_result = []
                for box in boxes:
                    crop_image, crop_box = self._run_box(image, mask, box, config)
                    crop_result.append((crop_image, crop_box))

                inpaint_result = image[:, :, ::-1]
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image

        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
                downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)

                logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
                inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)

                # only paste masked area result
                inpaint_result = cv2.resize(inpaint_result,
                                            (origin_size[1], origin_size[0]),
                                            interpolation=cv2.INTER_CUBIC)

                original_pixel_indices = mask != 255
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]

        if inpaint_result is None:
            inpaint_result = self._pad_forward(image, mask, config)

        return inpaint_result

    def _run_box(self, image, mask, box, config: Config):
        """

        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left,top,right,bottom]

        Returns:
            BGR IMAGE
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        l = max(cx - w // 2, 0)
        t = max(cy - h // 2, 0)
        r = min(cx + w // 2, img_w)
        b = min(cy + h // 2, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]

        logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
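The strategy dispatch that used to live inside LaMa.__call__ is now generic in InpaintModel.__call__: Crop inpaints a margin-padded box around each masked region, Resize downscales the longer side before inpainting and pastes only the masked pixels back, and Original runs the model on the full pad-to-modulo image. A hedged sketch exercising it with a dummy model that "inpaints" by blurring:

import cv2
import numpy as np

from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config, HDStrategy


class BlurModel(InpaintModel):
    # toy subclass for illustration only
    def init_model(self, device):
        pass

    def forward(self, image, mask, config: Config):
        # same-size output, BGR, like the real models
        return cv2.blur(np.ascontiguousarray(image[:, :, ::-1]), (15, 15))


image = np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8)
mask = np.zeros((300, 400), dtype=np.uint8)
mask[100:150, 100:200] = 255

config = Config(ldm_steps=1, hd_strategy=HDStrategy.CROP,
                hd_strategy_crop_margin=32,
                hd_strategy_crop_trigger_size=200,
                hd_strategy_resize_limit=200)
result = BlurModel('cpu')(image, mask, config)  # only masked pixels change

The paste-back in _pad_forward (result[mask != 255] = image[:, :, ::-1][mask != 255]) is what guarantees unmasked pixels pass through untouched.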
lama_cleaner/model/lama.py (new file, 64 lines)

import os

import cv2
import numpy as np
import torch
from loguru import logger

from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)


class LaMa(InpaintModel):
    pad_mod = 8

    def __init__(self, device):
        """

        Args:
            device:
        """
        super().__init__(device)
        self.device = device

    def init_model(self, device):
        if os.environ.get("LAMA_MODEL"):
            model_path = os.environ.get("LAMA_MODEL")
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        logger.info(f"Load LaMa model from: {model_path}")
        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        inpainted_image = self.model(image, mask)

        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        return cur_res
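forward now receives the raw [H, W, C] uint8 image and normalizes internally via helper.norm_img, which this diff does not show. A presumed equivalent, inferred from how forward consumes its output ([C, H, W] float32 in 0~1, ready for torch.from_numpy(...).unsqueeze(0)):

import numpy as np

def norm_img(np_img):
    # presumed helper: HWC uint8 -> CHW float32 in 0~1
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    return np_img.astype('float32') / 255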
lama_cleaner/model/ldm.py
@@ -2,13 +2,16 @@ import os

 import numpy as np
 import torch
+from loguru import logger
+
+from lama_cleaner.model.base import InpaintModel
+from lama_cleaner.schema import Config

 torch.manual_seed(42)
 import torch.nn as nn
 from tqdm import tqdm
-import cv2
-from lama_cleaner.helper import pad_img_to_modulo, download_model
-from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
+from lama_cleaner.helper import download_model, norm_img
+from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
     timestep_embedding

 LDM_ENCODE_MODEL_URL = os.environ.get(
@@ -217,7 +220,7 @@ class DDIMSampler(object):

         time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
-        print(f"Running DDIM Sampling with {total_steps} timesteps")
+        logger.info(f"Running DDIM Sampling with {total_steps} timesteps")

         iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

@@ -268,97 +271,39 @@ def load_jit_model(url, device):
     return model


-class LDM:
-    def __init__(self, device, steps=50):
+class LDM(InpaintModel):
+    pad_mod = 32
+
+    def __init__(self, device):
+        super().__init__(device)
         self.device = device
+
+    def init_model(self, device):
         self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
         self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
         self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
-
         model = LatentDiffusion(self.diffusion_model, device)
         self.sampler = DDIMSampler(model)
-        self.steps = steps

-    def _norm(self, tensor):
-        return tensor * 2.0 - 1.0
-
-    @torch.no_grad()
-    def __call__(self, image, mask):
+    def forward(self, image, mask, config: Config):
         """
-        image: [C, H, W] RGB
-        mask: [1, H, W]
+        image: [H, W, C] RGB
+        mask: [H, W, 1]
         return: BGR IMAGE
         """
         # image [1,3,512,512] float32
         # mask: [1,1,512,512] float32
         # masked_image: [1,3,512,512] float32
-        origin_height, origin_width = image.shape[1:]
-        image = pad_img_to_modulo(image, mod=32)
-        mask = pad_img_to_modulo(mask, mod=32)
-        padded_height, padded_width = image.shape[1:]
+        steps = config.ldm_steps
+        image = norm_img(image)
+        mask = norm_img(mask)

         mask[mask < 0.5] = 0
         mask[mask >= 0.5] = 1

-        # crop 512 x 512
-        if padded_width <= 512 or padded_height <= 512:
-            np_img = self._forward(image, mask, self.device)
-        else:
-            print("Try to zoom in")
-            # zoom in
-            # x,y,w,h
-            # box = self.box_from_bitmap(mask)
-            box = self.find_main_content(mask)
-            if box is None:
-                print("No bbox found")
-                np_img = self._forward(image, mask, self.device)
-            else:
-                print(f"box: {box}")
-                box_x, box_y, box_w, box_h = box
-                cx = box_x + box_w // 2
-                cy = box_y + box_h // 2
-
-                # w = max(512, box_w)
-                # h = max(512, box_h)
-                w = box_w + 512
-                h = box_h + 512
-
-                left = max(cx - w // 2, 0)
-                top = max(cy - h // 2, 0)
-                right = min(cx + w // 2, origin_width)
-                bottom = min(cy + h // 2, origin_height)
-
-                x = left
-                y = top
-                w = right - left
-                h = bottom - top
-
-                crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
-                crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
-
-                print(f"Apply zoom in size width x height: {crop_img.shape}")
-
-                crop_img_height, crop_img_width = crop_img.shape[1:]
-
-                crop_img = pad_img_to_modulo(crop_img, mod=32)
-                crop_mask = pad_img_to_modulo(crop_mask, mod=32)
-                # RGB
-                np_img = self._forward(crop_img, crop_mask, self.device)
-
-                image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
-                image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
-                np_img = image
-                # BGR to RGB
-                # np_img = image[:, :, ::-1]
-
-        np_img = np_img[0:origin_height, 0:origin_width, :]
-        np_img = np_img[:, :, ::-1]
-
-        return np_img
-
-    def _forward(self, image, mask, device):
-        image = torch.from_numpy(image).unsqueeze(0).to(device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
         masked_image = (1 - mask) * image

         image = self._norm(image)
@@ -371,47 +316,20 @@ class LDM:
         c = torch.cat((c, cc), dim=1)  # 1,4,128,128

         shape = (c.shape[1] - 1,) + c.shape[2:]
-        samples_ddim = self.sampler.sample(steps=self.steps,
+        samples_ddim = self.sampler.sample(steps=steps,
                                            conditioning=c,
                                            batch_size=c.shape[0],
                                            shape=shape)
         x_samples_ddim = self.cond_stage_model_decode(samples_ddim)  # samples_ddim: 1, 3, 128, 128 float32

-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-        mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
-        predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
+        inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

-        inpainted = (1 - mask) * image + mask * predicted_image
-        inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
-        np_img = inpainted.astype(np.uint8)
-        return np_img
+        # inpainted = (1 - mask) * image + mask * predicted_image
+        inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
+        inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
+        return inpainted_image

-    def find_main_content(self, bitmap: np.ndarray):
-        th2 = bitmap[0].astype(np.uint8)
-        row_sum = th2.sum(1)
-        col_sum = th2.sum(0)
-        xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
-        xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
-        ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
-        ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
-
-        left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
-        return left, top, right - left, bottom - top
-
-    def box_from_bitmap(self, bitmap):
-        """
-        bitmap: single map with shape (NUM_CLASSES, H, W),
-            whose values are binarized as {0, 1}
-        """
-        contours, _ = cv2.findContours(
-            (bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
-        )
-
-        contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
-        num_contours = len(contours)
-        print(f"contours size: {num_contours}")
-        if num_contours != 1:
-            return None
-
-        # x,y,w,h
-        return cv2.boundingRect(contours[0])
+    def _norm(self, tensor):
+        return tensor * 2.0 - 1.0
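With LDM as an InpaintModel subclass, the DDIM step count moves from a construction-time CLI flag (--ldm-steps) to the per-request Config, so the quality/latency trade-off can be chosen per call. A hedged sketch (random inputs; constructing LDM downloads the three TorchScript models):

import numpy as np

from lama_cleaner.model.ldm import LDM
from lama_cleaner.schema import Config, HDStrategy

def cfg(steps):
    return Config(ldm_steps=steps, hd_strategy=HDStrategy.ORIGINAL,
                  hd_strategy_crop_margin=128,
                  hd_strategy_crop_trigger_size=2048,
                  hd_strategy_resize_limit=2048)

image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:128, 64:128] = 255

model = LDM('cpu')
rough = model(image, mask, cfg(5))    # few DDIM steps: fast, rough
better = model(image, mask, cfg(50))  # more steps: slower, better quality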
lama_cleaner/model_manager.py (new file, 34 lines)

from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.schema import Config


class ModelManager:
    LAMA = 'lama'
    LDM = 'ldm'

    def __init__(self, name: str, device):
        self.name = name
        self.device = device
        self.model = self.init_model(name, device)

    def init_model(self, name: str, device):
        if name == self.LAMA:
            model = LaMa(device)
        elif name == self.LDM:
            model = LDM(device)
        else:
            raise NotImplementedError(f"Not supported model: {name}")
        return model

    def __call__(self, image, mask, config: Config):
        return self.model(image, mask, config)

    def switch(self, new_name: str):
        if new_name == self.name:
            return
        try:
            self.model = self.init_model(new_name, self.device)
            self.name = new_name
        except NotImplementedError as e:
            raise e
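Usage sketch for ModelManager (this mirrors what main.py does below; a failed switch leaves the previous model in place, because init_model raises before self.model is reassigned):

from lama_cleaner.model_manager import ModelManager

model = ModelManager(name='lama', device='cpu')  # loads LaMa weights
model.switch('ldm')    # loads LDM; subsequent calls use it
model.switch('ldm')    # same name: returns early, no reload
try:
    model.switch('sd')  # unknown name
except NotImplementedError as e:
    print(e)            # "Not supported model: sd"; old model still active

Note the try/except inside switch only re-raises, so it behaves the same as letting the NotImplementedError propagate.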
lama_cleaner/schema.py (new file, 17 lines)

from enum import Enum

from pydantic import BaseModel


class HDStrategy(str, Enum):
    ORIGINAL = 'Original'
    RESIZE = 'Resize'
    CROP = 'Crop'


class Config(BaseModel):
    ldm_steps: int
    hd_strategy: str
    hd_strategy_crop_margin: int
    hd_strategy_crop_trigger_size: int
    hd_strategy_resize_limit: int
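Config is a pydantic model, which is what lets the Flask handler below pass request.form values (all strings) straight into it; pydantic coerces numeric strings to int, so no manual casts are needed:

from lama_cleaner.schema import Config, HDStrategy

config = Config(
    ldm_steps='50',                        # str -> int via pydantic coercion
    hd_strategy=HDStrategy.CROP,           # str-based enum, equal to 'Crop'
    hd_strategy_crop_margin='128',
    hd_strategy_crop_trigger_size='2048',
    hd_strategy_resize_limit='2048',
)
assert config.ldm_steps == 50
assert config.hd_strategy == 'Crop'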
lama_cleaner/tests/__init__.py (new, empty)
BIN lama_cleaner/tests/image.png (new, 129 KiB)
BIN lama_cleaner/tests/lama_crop_result.png (new, 193 KiB)
BIN lama_cleaner/tests/lama_original_result.png (new, 193 KiB)
BIN lama_cleaner/tests/lama_resize_result.png (new, 193 KiB)
BIN lama_cleaner/tests/ldm_crop_result.png (new, 193 KiB)
BIN lama_cleaner/tests/ldm_original_result.png (new, 193 KiB)
BIN lama_cleaner/tests/ldm_resize_result.png (new, 193 KiB)
BIN (deleted image, 11 KiB)
BIN lama_cleaner/tests/mask.png (new, 7.7 KiB)
(deleted test file, 15 lines)

import cv2
import numpy as np

from lama_cleaner.helper import boxes_from_mask


def test_boxes_from_mask():
    mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
    mask = mask[:, :, np.newaxis]
    mask = (mask / 255).transpose(2, 0, 1)
    boxes = boxes_from_mask(mask)
    print(boxes)


test_boxes_from_mask()
lama_cleaner/tests/test_model.py (new file, 55 lines)

import os
from pathlib import Path

import cv2
import numpy as np
import pytest

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy

current_dir = Path(__file__).parent.absolute().resolve()


def get_data():
    img = cv2.imread(str(current_dir / 'image.png'))
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
    return img, mask


def get_config(strategy):
    return Config(
        ldm_steps=1,
        hd_strategy=strategy,
        hd_strategy_crop_margin=32,
        hd_strategy_crop_trigger_size=200,
        hd_strategy_resize_limit=200,
    )


def assert_equal(model, config, gt_name):
    img, mask = get_data()
    res = model(img, mask, config)
    # cv2.imwrite(gt_name, res,
    #             [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])

    """
    Note that JPEG is lossy compression, so even if it is the highest quality 100,
    when the saved image is reloaded, a difference occurs with the original pixel value.
    If you want to save the original image as it is, save it as PNG or BMP.
    """
    gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
    assert np.array_equal(res, gt)


@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_lama(strategy):
    model = ModelManager(name='lama', device='cpu')
    assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')


@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_ldm(strategy):
    model = ModelManager(name='ldm', device='cpu')
    assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')
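The golden images compared here are the PNGs added above; the commented-out imwrite is presumably how they were produced. A sketch for regenerating them after an intentional output change (PNG, not JPEG, for the reason given in the docstring):

import cv2

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import HDStrategy
from lama_cleaner.tests.test_model import get_config, get_data

for name in ('lama', 'ldm'):
    model = ModelManager(name=name, device='cpu')
    for s in (HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP):
        img, mask = get_data()
        res = model(img, mask, get_config(s))
        out = f'{name}_{s[0].upper() + s[1:]}_result.png'
        cv2.imwrite(out, res, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])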
main.py (94 changes)
@@ -2,19 +2,21 @@

 import argparse
 import io
+import logging
 import multiprocessing
 import os
 import time
 import imghdr
+from pathlib import Path
 from typing import Union

 import cv2
 import torch
 import numpy as np
-from lama_cleaner.lama import LaMa
-from lama_cleaner.ldm import LDM
+from loguru import logger

-from flaskwebgui import FlaskUI
+from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.schema import Config

 try:
     torch._C._jit_override_can_fuse_on_cpu(False)
@@ -29,7 +31,6 @@ from flask_cors import CORS

 from lama_cleaner.helper import (
     load_img,
-    norm_img,
     numpy_to_bytes,
     resize_max_size,
 )
@@ -46,11 +47,19 @@ if os.environ.get("CACHE_DIR"):

 BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")


+class InterceptHandler(logging.Handler):
+    def emit(self, record):
+        logger_opt = logger.opt(depth=6, exception=record.exc_info)
+        logger_opt.log(record.levelno, record.getMessage())
+
+
 app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
 app.config["JSON_AS_ASCII"] = False
-CORS(app)
+app.logger.addHandler(InterceptHandler())
+CORS(app, expose_headers=["Content-Disposition"])

-model = None
+model: ModelManager = None
 device = None
 input_image_path: str = None

@@ -72,24 +81,31 @@ def process():
     original_shape = image.shape
     interpolation = cv2.INTER_CUBIC

-    size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
+    form = request.form
+    size_limit: Union[int, str] = form.get("sizeLimit", "1080")
     if size_limit == "Original":
         size_limit = max(image.shape)
     else:
         size_limit = int(size_limit)

-    print(f"Origin image shape: {original_shape}")
+    config = Config(
+        ldm_steps=form['ldmSteps'],
+        hd_strategy=form['hdStrategy'],
+        hd_strategy_crop_margin=form['hdStrategyCropMargin'],
+        hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
+        hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
+    )
+
+    logger.info(f"Origin image shape: {original_shape}")
     image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
-    print(f"Resized image shape: {image.shape}")
-    image = norm_img(image)
+    logger.info(f"Resized image shape: {image.shape}")

     mask, _ = load_img(input["mask"].read(), gray=True)
     mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
-    mask = norm_img(mask)

     start = time.time()
-    res_np_img = model(image, mask)
-    print(f"process time: {(time.time() - start) * 1000}ms")
+    res_np_img = model(image, mask, config)
+    logger.info(f"process time: {(time.time() - start) * 1000}ms")

     torch.cuda.empty_cache()

@@ -109,6 +125,19 @@ def process():
     )


+@app.route("/switch_model", methods=["POST"])
+def switch_model():
+    new_name = request.form.get("name")
+    if new_name == model.name:
+        return "Same model", 200
+
+    try:
+        model.switch(new_name)
+    except NotImplementedError:
+        return f"{new_name} not implemented", 403
+    return f"ok, switch to {new_name}", 200
+
+
 @app.route("/")
 def index():
     return send_file(os.path.join(BUILD_DIR, "index.html"))
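Client-side counterpart of the new endpoint, matching the switchModel helper added to inpainting.ts (default host/port taken from get_args_parser below):

import requests

r = requests.post('http://127.0.0.1:8080/switch_model', data={'name': 'ldm'})
print(r.status_code, r.text)  # 200 'ok, switch to ldm'; 403 if the name is unknown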
@@ -120,7 +149,9 @@ def set_input_photo():
         with open(input_image_path, "rb") as f:
             image_in_bytes = f.read()
         return send_file(
-            io.BytesIO(image_in_bytes),
+            input_image_path,
+            as_attachment=True,
+            download_name=Path(input_image_path).name,
             mimetype=f"image/{get_image_ext(image_in_bytes)}",
         )
     else:
@@ -135,29 +166,6 @@ def get_args_parser():
     parser.add_argument("--host", default="127.0.0.1")
     parser.add_argument("--port", default=8080, type=int)
     parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
-    parser.add_argument(
-        "--crop-trigger-size",
-        default=[2042, 2042],
-        nargs=2,
-        type=int,
-        help="If image size large then crop-trigger-size, "
-        "crop each area from original image to do inference."
-        "Mainly for performance and memory reasons"
-        "Only for lama",
-    )
-    parser.add_argument(
-        "--crop-margin",
-        type=int,
-        default=256,
-        help="Margin around bounding box of painted stroke when crop mode triggered",
-    )
-    parser.add_argument(
-        "--ldm-steps",
-        default=50,
-        type=int,
-        help="Steps for DDIM sampling process."
-        "The larger the value, the better the result, but it will be more time-consuming",
-    )
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
     parser.add_argument(
@@ -188,19 +196,11 @@ def main():
     device = torch.device(args.device)
     input_image_path = args.input

-    if args.model == "lama":
-        model = LaMa(
-            crop_trigger_size=args.crop_trigger_size,
-            crop_margin=args.crop_margin,
-            device=device,
-        )
-    elif args.model == "ldm":
-        model = LDM(device, steps=args.ldm_steps)
-    else:
-        raise NotImplementedError(f"Not supported model: {args.model}")
+    model = ModelManager(name=args.model, device=device)

     if args.gui:
         app_width, app_height = args.gui_size
+        from flaskwebgui import FlaskUI
         ui = FlaskUI(app, width=app_width, height=app_height)
         ui.run()
     else:
requirements.txt
@@ -1,6 +1,9 @@
 torch>=1.8.2
 opencv-python
 flask_cors
-flask
+flask==2.1.1
 flaskwebgui
 tqdm
+pydantic
+loguru
+pytest