big update

This commit is contained in:
Sanster 2022-04-16 00:11:51 +08:00
parent 2b031603ed
commit 205286a414
40 changed files with 539 additions and 376 deletions

View File

@ -1,10 +1,12 @@
import { Setting } from '../store/Atoms'
import { dataURItoBlob } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
export default async function inpaint(
imageFile: File,
maskBase64: string,
settings: Setting,
sizeLimit?: string
) {
// 1080, 2000, Original
@ -13,13 +15,22 @@ export default async function inpaint(
const mask = dataURItoBlob(maskBase64)
fd.append('mask', mask)
fd.append('ldmSteps', settings.ldmSteps.toString())
fd.append('hdStrategy', settings.hdStrategy)
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
fd.append(
'hdStrategyCropTrigerSize',
settings.hdStrategyCropTrigerSize.toString()
)
fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {
fd.append('sizeLimit', sizeLimit)
}
const res = await fetch(API_ENDPOINT, {
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
method: 'POST',
body: fd,
}).then(async r => {
@ -28,3 +39,12 @@ export default async function inpaint(
return URL.createObjectURL(res)
}
export function switchModel(name: string) {
const fd = new FormData()
fd.append('name', name)
return fetch(`${API_ENDPOINT}/switch_model`, {
method: 'POST',
body: fd,
})
}

View File

@ -15,12 +15,14 @@ import {
TransformComponent,
TransformWrapper,
} from 'react-zoom-pan-pinch'
import { useRecoilValue } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint from '../../adapters/inpainting'
import Button from '../shared/Button'
import Slider from './Slider'
import SizeSelector from './SizeSelector'
import { downloadImage, loadImage, useImage } from '../../utils'
import { settingState } from '../../store/Atoms'
const TOOLBAR_SIZE = 200
const BRUSH_COLOR = '#ffcc00bb'
@ -57,6 +59,7 @@ function drawLines(
export default function Editor(props: EditorProps) {
const { file } = props
const settings = useRecoilValue(settingState)
const [brushSize, setBrushSize] = useState(40)
const [original, isOriginalLoaded] = useImage(file)
const [renders, setRenders] = useState<HTMLImageElement[]>([])
@ -125,6 +128,7 @@ export default function Editor(props: EditorProps) {
const res = await inpaint(
file,
maskCanvas.toDataURL(),
settings,
sizeLimit.toString()
)
if (!res) {
@ -157,6 +161,7 @@ export default function Editor(props: EditorProps) {
renders,
sizeLimit,
historyLineCount,
settings,
])
const hadDrawSomething = () => {

View File

@ -6,7 +6,7 @@ import Button from '../shared/Button'
import Shortcuts from '../Shortcuts/Shortcuts'
import useResolution from '../../hooks/useResolution'
import { ThemeChanger } from './ThemeChanger'
import SettingIcon from '../Setting/SettingIcon'
import SettingIcon from '../Settings/SettingIcon'
const Header = () => {
const [file, setFile] = useRecoilState(fileState)

View File

@ -1,50 +1,16 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import NumberInput from '../shared/NumberInput'
import Selector from '../shared/Selector'
import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'
export enum HDStrategy {
ORIGINAL = 'Original',
REISIZE = 'Resize',
RESIZE = 'Resize',
CROP = 'Crop',
}
interface PixelSizeInputProps {
title: string
value: string
onValue: (val: string) => void
}
function PixelSizeInputSetting(props: PixelSizeInputProps) {
const { title, value, onValue } = props
return (
<SettingBlock
className="sub-setting-block"
title={title}
input={
<div
style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
gap: '8px',
}}
>
<NumberInput
style={{ width: '80px' }}
value={`${value}`}
onValue={onValue}
/>
<span>pixel</span>
</div>
}
/>
)
}
function HDSettingBlock() {
const [setting, setSettingState] = useRecoilState(settingState)
@ -84,7 +50,7 @@ function HDSettingBlock() {
tabIndex={0}
role="button"
className="inline-tip"
onClick={() => onStrategyChange(HDStrategy.REISIZE)}
onClick={() => onStrategyChange(HDStrategy.RESIZE)}
>
Resize Strategy
</div>{' '}
@ -100,9 +66,10 @@ function HDSettingBlock() {
Resize the longer side of the image to a specific size(keep ratio),
then do inpainting on the resized image.
</div>
<PixelSizeInputSetting
<NumberInputSetting
title="Size limit"
value={`${setting.hdStrategyResizeLimit}`}
suffix="pixel"
onValue={onResizeLimitChange}
/>
</div>
@ -117,14 +84,16 @@ function HDSettingBlock() {
the result back. Mainly for performance and memory reasons on high
resolution image.
</div>
<PixelSizeInputSetting
<NumberInputSetting
title="Trigger size"
value={`${setting.hdStrategyCropTrigerSize}`}
suffix="pixel"
onValue={onCropTriggerSizeChange}
/>
<PixelSizeInputSetting
<NumberInputSetting
title="Crop margin"
value={`${setting.hdStrategyCropMargin}`}
suffix="pixel"
onValue={onCropMarginChange}
/>
</div>
@ -137,7 +106,7 @@ function HDSettingBlock() {
return renderOriginalOptionDesc()
case HDStrategy.CROP:
return renderCropOptionDesc()
case HDStrategy.REISIZE:
case HDStrategy.RESIZE:
return renderResizeOptionDesc()
default:
return renderOriginalOptionDesc()

View File

@ -2,11 +2,12 @@ import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Selector from '../shared/Selector'
import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'
export enum AIModel {
LAMA = 'LaMa',
LDM = 'LDM',
LAMA = 'lama',
LDM = 'ldm',
}
function ModelSettingBlock() {
@ -24,7 +25,7 @@ function ModelSettingBlock() {
githubUrl: string
) => {
return (
<div style={{ display: 'flex', flexDirection: 'column' }}>
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
<a
className="model-desc-link"
href={paperUrl}
@ -34,14 +35,11 @@ function ModelSettingBlock() {
{name}
</a>
<br />
<a
className="model-desc-link"
href={githubUrl}
target="_blank"
rel="noreferrer noopener"
style={{ marginTop: '8px' }}
>
{githubUrl}
</a>
@ -49,6 +47,28 @@ function ModelSettingBlock() {
)
}
const renderLDMModelDesc = () => {
return (
<div>
{renderModelDesc(
'High-Resolution Image Synthesis with Latent Diffusion Models',
'https://arxiv.org/abs/2112.10752',
'https://github.com/CompVis/latent-diffusion'
)}
<NumberInputSetting
title="Steps"
value={`${setting.ldmSteps}`}
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, ldmSteps: val }
})
}}
/>
</div>
)
}
const renderOptionDesc = (): ReactNode => {
switch (setting.model) {
case AIModel.LAMA:
@ -58,11 +78,7 @@ function ModelSettingBlock() {
'https://github.com/saic-mdal/lama'
)
case AIModel.LDM:
return renderModelDesc(
'High-Resolution Image Synthesis with Latent Diffusion Models',
'https://arxiv.org/abs/2112.10752',
'https://github.com/CompVis/latent-diffusion'
)
return renderLDMModelDesc()
default:
return <></>
}

View File

@ -0,0 +1,40 @@
import React from 'react'
import NumberInput from '../shared/NumberInput'
import SettingBlock from './SettingBlock'
interface NumberInputSettingProps {
title: string
value: string
suffix?: string
onValue: (val: string) => void
}
function NumberInputSetting(props: NumberInputSettingProps) {
const { title, value, suffix, onValue } = props
return (
<SettingBlock
className="sub-setting-block"
title={title}
input={
<div
style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
gap: '8px',
}}
>
<NumberInput
style={{ width: '80px' }}
value={`${value}`}
onValue={onValue}
/>
{suffix && <span>{suffix}</span>}
</div>
}
/>
)
}
export default NumberInputSetting

View File

@ -1,14 +1,26 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import { Switch, SwitchThumb } from '../shared/Switch'
import SettingBlock from './SettingBlock'
function SavePathSettingBlock() {
const [setting, setSettingState] = useRecoilState(settingState)
const onCheckChange = (checked: boolean) => {
setSettingState(old => {
return { ...old, saveImageBesideOrigin: checked }
})
}
return (
<SettingBlock
title="Download image beside origin image"
input={
<Switch defaultChecked>
<Switch
checked={setting.saveImageBesideOrigin}
onCheckedChange={onCheckChange}
>
<SwitchThumb />
</Switch>
}

View File

@ -7,7 +7,7 @@
margin-top: 12px;
border: 1px solid var(--border-color);
border-radius: 0.3rem;
padding: 2rem;
padding: 1rem;
.sub-setting-block {
margin-top: 8px;

View File

@ -8,7 +8,6 @@
background-color: var(--modal-bg);
color: var(--modal-text-color);
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
min-height: 600px;
width: 700px;
@include mobile {

View File

@ -1,11 +1,11 @@
import React from 'react'
import { useRecoilState } from 'recoil'
import { switchModel } from '../../adapters/inpainting'
import { settingState } from '../../store/Atoms'
import Modal from '../shared/Modal'
import HDSettingBlock from './HDSettingBlock'
import ModelSettingBlock from './ModelSettingBlock'
import SavePathSettingBlock from './SavePathSettingBlock'
export default function SettingModal() {
const [setting, setSettingState] = useRecoilState(settingState)
@ -14,6 +14,8 @@ export default function SettingModal() {
setSettingState(old => {
return { ...old, show: false }
})
switchModel(setting.model)
}
return (
@ -23,7 +25,9 @@ export default function SettingModal() {
className="modal-setting"
show={setting.show}
>
<SavePathSettingBlock />
{/* It's not possible because this poses a security risk */}
{/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
{/* <SavePathSettingBlock /> */}
<ModelSettingBlock />
<HDSettingBlock />
</Modal>

View File

@ -1,9 +1,7 @@
import React from 'react'
import { useRecoilValue } from 'recoil'
import Editor from './Editor/Editor'
import { shortcutsState } from '../store/Atoms'
import ShortcutsModal from './Shortcuts/ShortcutsModal'
import SettingModal from './Setting/SettingModal'
import SettingModal from './Settings/SettingsModal'
interface WorkspaceProps {
file: File

View File

@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
const ref = useRef(null)
useClickAway(ref, () => {
onClose?.()
if (show) {
onClose?.()
}
})
useKeyPressEvent('Escape', e => {
onClose?.()
if (show) {
onClose?.()
}
})
return (

View File

@ -1,6 +1,6 @@
import { atom } from 'recoil'
import { HDStrategy } from '../components/Setting/HDSettingBlock'
import { AIModel } from '../components/Setting/ModelSettingBlock'
import { HDStrategy } from '../components/Settings/HDSettingBlock'
import { AIModel } from '../components/Settings/ModelSettingBlock'
export const fileState = atom<File | undefined>({
key: 'fileState',
@ -16,10 +16,15 @@ export interface Setting {
show: boolean
saveImageBesideOrigin: boolean
model: AIModel
// For LaMa
hdStrategy: HDStrategy
hdStrategyResizeLimit: number
hdStrategyCropTrigerSize: number
hdStrategyCropMargin: number
// For LDM
ldmSteps: number
}
export const settingState = atom<Setting>({
@ -28,9 +33,10 @@ export const settingState = atom<Setting>({
show: false,
saveImageBesideOrigin: false,
model: AIModel.LAMA,
hdStrategy: HDStrategy.ORIGINAL,
hdStrategy: HDStrategy.RESIZE,
hdStrategyResizeLimit: 2048,
hdStrategyCropTrigerSize: 2048,
hdStrategyCropMargin: 128,
ldmSteps: 50,
},
})

View File

@ -11,7 +11,7 @@
@use '../components/Header/Header';
@use '../components/Header/ThemeChanger';
@use '../components/Shortcuts/Shortcuts';
@use '../components/Setting/Setting.scss';
@use '../components/Settings/Settings.scss';
// Shared
@use '../components/FileSelect/FileSelect';

View File

@ -31,7 +31,11 @@ def ceil_modulo(x, mod):
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
data = cv2.imencode(f".{ext}", image_numpy)[1]
data = cv2.imencode(f".{ext}", image_numpy,
[
int(cv2.IMWRITE_JPEG_QUALITY), 100,
int(cv2.IMWRITE_PNG_COMPRESSION), 0
])[1]
image_bytes = data.tobytes()
return image_bytes
@ -74,13 +78,24 @@ def resize_max_size(
return np_img
def pad_img_to_modulo(img, mod):
channels, height, width = img.shape
def pad_img_to_modulo(img: np.ndarray, mod: int):
"""
Args:
img: [H, W, C]
mod:
Returns:
"""
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
height, width = img.shape[:2]
out_height = ceil_modulo(height, mod)
out_width = ceil_modulo(width, mod)
return np.pad(
img,
((0, 0), (0, out_height - height), (0, out_width - width)),
((0, out_height - height), (0, out_width - width), (0, 0)),
mode="symmetric",
)
@ -88,15 +103,13 @@ def pad_img_to_modulo(img, mod):
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
"""
Args:
mask: (1, h, w) 0~1
mask: (h, w, 1) 0~255
Returns:
"""
height, width = mask.shape[1:]
_, thresh = cv2.threshold(
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
)
height, width = mask.shape[:2]
_, thresh = cv2.threshold(mask, 127, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = []

View File

@ -1,121 +0,0 @@
import os
from typing import List
import cv2
import torch
import numpy as np
from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
class LaMa:
def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
"""
Args:
crop_trigger_size: h, w
crop_margin:
device:
"""
self.crop_trigger_size = crop_trigger_size
self.crop_margin = crop_margin
self.device = device
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else:
model_path = download_model(LAMA_MODEL_URL)
print(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
@torch.no_grad()
def __call__(self, image, mask):
"""
image: [C, H, W] RGB
mask: [1, H, W]
return: BGR IMAGE
"""
area = image.shape[1] * image.shape[2]
if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
return self._run(image, mask)
print("Trigger crop image")
boxes = boxes_from_mask(mask)
crop_result = []
for box in boxes:
crop_image, crop_box = self._run_box(image, mask, box)
crop_result.append((crop_image, crop_box))
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
image[y1:y2, x1:x2, :] = crop_image
return image
def _run_box(self, image, mask, box):
"""
Args:
image: [C, H, W] RGB
mask: [1, H, W]
box: [left,top,right,bottom]
Returns:
BGR IMAGE
"""
box_h = box[3] - box[1]
box_w = box[2] - box[0]
cx = (box[0] + box[2]) // 2
cy = (box[1] + box[3]) // 2
img_h, img_w = image.shape[1:]
w = box_w + self.crop_margin * 2
h = box_h + self.crop_margin * 2
l = max(cx - w // 2, 0)
t = max(cy - h // 2, 0)
r = min(cx + w // 2, img_w)
b = min(cy + h // 2, img_h)
crop_img = image[:, t:b, l:r]
crop_mask = mask[:, t:b, l:r]
print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
return self._run(crop_img, crop_mask), [l, t, r, b]
def _run(self, image, mask):
"""
image: [C, H, W] RGB
mask: [1, H, W]
return: BGR IMAGE
"""
device = self.device
origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=8)
mask = pad_img_to_modulo(mask, mod=8)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
inpainted_image = self.model(image, mask)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = cur_res[0:origin_height, 0:origin_width, :]
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
return cur_res

122
lama_cleaner/model/base.py Normal file
View File

@ -0,0 +1,122 @@
import abc
import cv2
import torch
from loguru import logger
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
from lama_cleaner.schema import Config, HDStrategy
class InpaintModel:
pad_mod = 8
def __init__(self, device):
"""
Args:
device:
"""
self.device = device
self.init_model(device)
@abc.abstractmethod
def init_model(self, device):
...
@abc.abstractmethod
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
return: BGR IMAGE
"""
...
def _pad_forward(self, image, mask, config: Config):
origin_height, origin_width = image.shape[:2]
padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
result = self.forward(padd_image, padd_mask, config)
result = result[0:origin_height, 0:origin_width, :]
original_pixel_indices = mask != 255
result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return result
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
image: [H, W, C] RGB, not normalized
mask: [H, W]
return: BGR IMAGE
"""
inpaint_result = None
logger.info(f"hd_strategy: {config.hd_strategy}")
if config.hd_strategy == HDStrategy.CROP:
if max(image.shape) > config.hd_strategy_crop_trigger_size:
logger.info(f"Run crop strategy")
boxes = boxes_from_mask(mask)
crop_result = []
for box in boxes:
crop_image, crop_box = self._run_box(image, mask, box, config)
crop_result.append((crop_image, crop_box))
inpaint_result = image[:, :, ::-1]
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
inpaint_result[y1:y2, x1:x2, :] = crop_image
elif config.hd_strategy == HDStrategy.RESIZE:
if max(image.shape) > config.hd_strategy_resize_limit:
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC)
original_pixel_indices = mask != 255
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
if inpaint_result is None:
inpaint_result = self._pad_forward(image, mask, config)
return inpaint_result
def _run_box(self, image, mask, box, config: Config):
"""
Args:
image: [H, W, C] RGB
mask: [H, W, 1]
box: [left,top,right,bottom]
Returns:
BGR IMAGE
"""
box_h = box[3] - box[1]
box_w = box[2] - box[0]
cx = (box[0] + box[2]) // 2
cy = (box[1] + box[3]) // 2
img_h, img_w = image.shape[:2]
w = box_w + config.hd_strategy_crop_margin * 2
h = box_h + config.hd_strategy_crop_margin * 2
l = max(cx - w // 2, 0)
t = max(cy - h // 2, 0)
r = min(cx + w // 2, img_w)
b = min(cy + h // 2, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]

View File

@ -0,0 +1,64 @@
import os
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
class LaMa(InpaintModel):
pad_mod = 8
def __init__(self, device):
"""
Args:
device:
"""
super().__init__(device)
self.device = device
def init_model(self, device):
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else:
model_path = download_model(LAMA_MODEL_URL)
logger.info(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
return: BGR IMAGE
"""
image = norm_img(image)
mask = norm_img(mask)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
inpainted_image = self.model(image, mask)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
return cur_res

View File

@ -2,13 +2,16 @@ import os
import numpy as np
import torch
from loguru import logger
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
import cv2
from lama_cleaner.helper import pad_img_to_modulo, download_model
from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
from lama_cleaner.helper import download_model, norm_img
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
timestep_embedding
LDM_ENCODE_MODEL_URL = os.environ.get(
@ -217,7 +220,7 @@ class DDIMSampler(object):
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
@ -268,97 +271,39 @@ def load_jit_model(url, device):
return model
class LDM:
def __init__(self, device, steps=50):
class LDM(InpaintModel):
pad_mod = 32
def __init__(self, device):
super().__init__(device)
self.device = device
def init_model(self, device):
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
model = LatentDiffusion(self.diffusion_model, device)
self.sampler = DDIMSampler(model)
self.steps = steps
def _norm(self, tensor):
return tensor * 2.0 - 1.0
@torch.no_grad()
def __call__(self, image, mask):
def forward(self, image, mask, config: Config):
"""
image: [C, H, W] RGB
mask: [1, H, W]
image: [H, W, C] RGB
mask: [H, W, 1]
return: BGR IMAGE
"""
# image [1,3,512,512] float32
# mask: [1,1,512,512] float32
# masked_image: [1,3,512,512] float32
origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=32)
mask = pad_img_to_modulo(mask, mod=32)
padded_height, padded_width = image.shape[1:]
steps = config.ldm_steps
image = norm_img(image)
mask = norm_img(mask)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# crop 512 x 512
if padded_width <= 512 or padded_height <= 512:
np_img = self._forward(image, mask, self.device)
else:
print("Try to zoom in")
# zoom in
# x,y,w,h
# box = self.box_from_bitmap(mask)
box = self.find_main_content(mask)
if box is None:
print("No bbox found")
np_img = self._forward(image, mask, self.device)
else:
print(f"box: {box}")
box_x, box_y, box_w, box_h = box
cx = box_x + box_w // 2
cy = box_y + box_h // 2
# w = max(512, box_w)
# h = max(512, box_h)
w = box_w + 512
h = box_h + 512
left = max(cx - w // 2, 0)
top = max(cy - h // 2, 0)
right = min(cx + w // 2, origin_width)
bottom = min(cy + h // 2, origin_height)
x = left
y = top
w = right - left
h = bottom - top
crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
print(f"Apply zoom in size width x height: {crop_img.shape}")
crop_img_height, crop_img_width = crop_img.shape[1:]
crop_img = pad_img_to_modulo(crop_img, mod=32)
crop_mask = pad_img_to_modulo(crop_mask, mod=32)
# RGB
np_img = self._forward(crop_img, crop_mask, self.device)
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
np_img = image
# BGR to RGB
# np_img = image[:, :, ::-1]
np_img = np_img[0:origin_height, 0:origin_width, :]
np_img = np_img[:, :, ::-1]
return np_img
def _forward(self, image, mask, device):
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
masked_image = (1 - mask) * image
image = self._norm(image)
@ -371,47 +316,20 @@ class LDM:
c = torch.cat((c, cc), dim=1) # 1,4,128,128
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim = self.sampler.sample(steps=self.steps,
samples_ddim = self.sampler.sample(steps=steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape)
x_samples_ddim = self.cond_stage_model_decode(samples_ddim) # samples_ddim: 1, 3, 128, 128 float32
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
inpainted = (1 - mask) * image + mask * predicted_image
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
np_img = inpainted.astype(np.uint8)
return np_img
# inpainted = (1 - mask) * image + mask * predicted_image
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
return inpainted_image
def find_main_content(self, bitmap: np.ndarray):
th2 = bitmap[0].astype(np.uint8)
row_sum = th2.sum(1)
col_sum = th2.sum(0)
xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
return left, top, right - left, bottom - top
def box_from_bitmap(self, bitmap):
"""
bitmap: single map with shape (NUM_CLASSES, H, W),
whose values are binarized as {0, 1}
"""
contours, _ = cv2.findContours(
(bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
)
contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
num_contours = len(contours)
print(f"contours size: {num_contours}")
if num_contours != 1:
return None
# x,y,w,h
return cv2.boundingRect(contours[0])
def _norm(self, tensor):
return tensor * 2.0 - 1.0

View File

@ -0,0 +1,34 @@
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.schema import Config
class ModelManager:
LAMA = 'lama'
LDM = 'ldm'
def __init__(self, name: str, device):
self.name = name
self.device = device
self.model = self.init_model(name, device)
def init_model(self, name: str, device):
if name == self.LAMA:
model = LaMa(device)
elif name == self.LDM:
model = LDM(device)
else:
raise NotImplementedError(f"Not supported model: {name}")
return model
def __call__(self, image, mask, config: Config):
return self.model(image, mask, config)
def switch(self, new_name: str):
if new_name == self.name:
return
try:
self.model = self.init_model(new_name, self.device)
self.name = new_name
except NotImplementedError as e:
raise e

17
lama_cleaner/schema.py Normal file
View File

@ -0,0 +1,17 @@
from enum import Enum
from pydantic import BaseModel
class HDStrategy(str, Enum):
ORIGINAL = 'Original'
RESIZE = 'Resize'
CROP = 'Crop'
class Config(BaseModel):
ldm_steps: int
hd_strategy: str
hd_strategy_crop_margin: int
hd_strategy_crop_trigger_size: int
hd_strategy_resize_limit: int

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

BIN
lama_cleaner/tests/mask.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.7 KiB

View File

@ -1,15 +0,0 @@
import cv2
import numpy as np
from lama_cleaner.helper import boxes_from_mask
def test_boxes_from_mask():
mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
mask = mask[:, :, np.newaxis]
mask = (mask / 255).transpose(2, 0, 1)
boxes = boxes_from_mask(mask)
print(boxes)
test_boxes_from_mask()

View File

@ -0,0 +1,55 @@
import os
from pathlib import Path
import cv2
import numpy as np
import pytest
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy
current_dir = Path(__file__).parent.absolute().resolve()
def get_data():
img = cv2.imread(str(current_dir / 'image.png'))
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
return img, mask
def get_config(strategy):
return Config(
ldm_steps=1,
hd_strategy=strategy,
hd_strategy_crop_margin=32,
hd_strategy_crop_trigger_size=200,
hd_strategy_resize_limit=200,
)
def assert_equal(model, config, gt_name):
img, mask = get_data()
res = model(img, mask, config)
# cv2.imwrite(gt_name, res,
# [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])
"""
Note that JPEG is lossy compression, so even if it is the highest quality 100,
when the saved image is reloaded, a difference occurs with the original pixel value.
If you want to save the original image as it is, save it as PNG or BMP.
"""
gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
assert np.array_equal(res, gt)
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_lama(strategy):
model = ModelManager(name='lama', device='cpu')
assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_ldm(strategy):
model = ModelManager(name='ldm', device='cpu')
assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')

94
main.py
View File

@ -2,19 +2,21 @@
import argparse
import io
import logging
import multiprocessing
import os
import time
import imghdr
from pathlib import Path
from typing import Union
import cv2
import torch
import numpy as np
from lama_cleaner.lama import LaMa
from lama_cleaner.ldm import LDM
from loguru import logger
from flaskwebgui import FlaskUI
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
try:
torch._C._jit_override_can_fuse_on_cpu(False)
@ -29,7 +31,6 @@ from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
norm_img,
numpy_to_bytes,
resize_max_size,
)
@ -46,11 +47,19 @@ if os.environ.get("CACHE_DIR"):
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
class InterceptHandler(logging.Handler):
def emit(self, record):
logger_opt = logger.opt(depth=6, exception=record.exc_info)
logger_opt.log(record.levelno, record.getMessage())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app)
app.logger.addHandler(InterceptHandler())
CORS(app, expose_headers=["Content-Disposition"])
model = None
model: ModelManager = None
device = None
input_image_path: str = None
@ -72,24 +81,31 @@ def process():
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
form = request.form
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
print(f"Origin image shape: {original_shape}")
config = Config(
ldm_steps=form['ldmSteps'],
hd_strategy=form['hdStrategy'],
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
)
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
print(f"Resized image shape: {image.shape}")
image = norm_img(image)
logger.info(f"Resized image shape: {image.shape}")
mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask)
start = time.time()
res_np_img = model(image, mask)
print(f"process time: {(time.time() - start) * 1000}ms")
res_np_img = model(image, mask, config)
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
@ -109,6 +125,19 @@ def process():
)
@app.route("/switch_model", methods=["POST"])
def switch_model():
new_name = request.form.get("name")
if new_name == model.name:
return "Same model", 200
try:
model.switch(new_name)
except NotImplementedError:
return f"{new_name} not implemented", 403
return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@ -120,7 +149,9 @@ def set_input_photo():
with open(input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
io.BytesIO(image_in_bytes),
input_image_path,
as_attachment=True,
download_name=Path(input_image_path).name,
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
@ -135,29 +166,6 @@ def get_args_parser():
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument(
"--crop-trigger-size",
default=[2042, 2042],
nargs=2,
type=int,
help="If image size large then crop-trigger-size, "
"crop each area from original image to do inference."
"Mainly for performance and memory reasons"
"Only for lama",
)
parser.add_argument(
"--crop-margin",
type=int,
default=256,
help="Margin around bounding box of painted stroke when crop mode triggered",
)
parser.add_argument(
"--ldm-steps",
default=50,
type=int,
help="Steps for DDIM sampling process."
"The larger the value, the better the result, but it will be more time-consuming",
)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(
@ -188,19 +196,11 @@ def main():
device = torch.device(args.device)
input_image_path = args.input
if args.model == "lama":
model = LaMa(
crop_trigger_size=args.crop_trigger_size,
crop_margin=args.crop_margin,
device=device,
)
elif args.model == "ldm":
model = LDM(device, steps=args.ldm_steps)
else:
raise NotImplementedError(f"Not supported model: {args.model}")
model = ModelManager(name=args.model, device=device)
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI
ui = FlaskUI(app, width=app_width, height=app_height)
ui.run()
else:

View File

@ -1,6 +1,9 @@
torch>=1.8.2
opencv-python
flask_cors
flask
flask==2.1.1
flaskwebgui
tqdm
pydantic
loguru
pytest