resize image using backend;add resize radio button
frontend resize image will reduce image quality
This commit is contained in:
parent
1c2e7fa559
commit
1e2c8fd348
@ -4,6 +4,7 @@
|
|||||||
"private": true,
|
"private": true,
|
||||||
"proxy": "http://localhost:8080",
|
"proxy": "http://localhost:8080",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@headlessui/react": "^1.4.2",
|
||||||
"@heroicons/react": "^1.0.4",
|
"@heroicons/react": "^1.0.4",
|
||||||
"@testing-library/jest-dom": "^5.14.1",
|
"@testing-library/jest-dom": "^5.14.1",
|
||||||
"@testing-library/react": "^12.1.2",
|
"@testing-library/react": "^12.1.2",
|
||||||
|
@ -4,7 +4,6 @@ import { useWindowSize } from 'react-use'
|
|||||||
import Button from './components/Button'
|
import Button from './components/Button'
|
||||||
import FileSelect from './components/FileSelect'
|
import FileSelect from './components/FileSelect'
|
||||||
import Editor from './Editor'
|
import Editor from './Editor'
|
||||||
import { resizeImageFile } from './utils'
|
|
||||||
|
|
||||||
function App() {
|
function App() {
|
||||||
const [file, setFile] = useState<File>()
|
const [file, setFile] = useState<File>()
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import { DownloadIcon, EyeIcon } from '@heroicons/react/outline'
|
import { DownloadIcon, EyeIcon } from '@heroicons/react/outline'
|
||||||
import React, { useCallback, useEffect, useState } from 'react'
|
import React, { useCallback, useEffect, useState } from 'react'
|
||||||
import { useWindowSize } from 'react-use'
|
import { useWindowSize, useLocalStorage } from 'react-use'
|
||||||
import inpaint from './adapters/inpainting'
|
import inpaint from './adapters/inpainting'
|
||||||
import Button from './components/Button'
|
import Button from './components/Button'
|
||||||
import Slider from './components/Slider'
|
import Slider from './components/Slider'
|
||||||
import { downloadImage, loadImage, shareImage, useImage } from './utils'
|
import SizeSelector from './components/SizeSelector'
|
||||||
|
import { downloadImage, loadImage, useImage } from './utils'
|
||||||
|
|
||||||
const TOOLBAR_SIZE = 200
|
const TOOLBAR_SIZE = 200
|
||||||
const BRUSH_COLOR = 'rgba(189, 255, 1, 0.75)'
|
const BRUSH_COLOR = 'rgba(189, 255, 1, 0.75)'
|
||||||
@ -55,6 +56,8 @@ export default function Editor(props: EditorProps) {
|
|||||||
const [isInpaintingLoading, setIsInpaintingLoading] = useState(false)
|
const [isInpaintingLoading, setIsInpaintingLoading] = useState(false)
|
||||||
const [showSeparator, setShowSeparator] = useState(false)
|
const [showSeparator, setShowSeparator] = useState(false)
|
||||||
const [scale, setScale] = useState(1)
|
const [scale, setScale] = useState(1)
|
||||||
|
// ['1080', '2000', 'Original']
|
||||||
|
const [sizeLimit, setSizeLimit] = useLocalStorage('sizeLimit', '1080')
|
||||||
const windowSize = useWindowSize()
|
const windowSize = useWindowSize()
|
||||||
|
|
||||||
const draw = useCallback(() => {
|
const draw = useCallback(() => {
|
||||||
@ -144,8 +147,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
window.removeEventListener('mouseup', onPointerUp)
|
window.removeEventListener('mouseup', onPointerUp)
|
||||||
refreshCanvasMask()
|
refreshCanvasMask()
|
||||||
try {
|
try {
|
||||||
const start = Date.now()
|
const res = await inpaint(file, maskCanvas.toDataURL(), sizeLimit)
|
||||||
const res = await inpaint(file, maskCanvas.toDataURL())
|
|
||||||
if (!res) {
|
if (!res) {
|
||||||
throw new Error('empty response')
|
throw new Error('empty response')
|
||||||
}
|
}
|
||||||
@ -221,6 +223,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
original.naturalHeight,
|
original.naturalHeight,
|
||||||
original.naturalWidth,
|
original.naturalWidth,
|
||||||
scale,
|
scale,
|
||||||
|
sizeLimit,
|
||||||
])
|
])
|
||||||
|
|
||||||
const undo = useCallback(() => {
|
const undo = useCallback(() => {
|
||||||
@ -252,12 +255,16 @@ export default function Editor(props: EditorProps) {
|
|||||||
}, [renders, undo])
|
}, [renders, undo])
|
||||||
|
|
||||||
function download() {
|
function download() {
|
||||||
const base64 = context?.canvas.toDataURL(file.type)
|
|
||||||
if (!base64) {
|
|
||||||
throw new Error('could not get canvas data')
|
|
||||||
}
|
|
||||||
const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1')
|
const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1')
|
||||||
downloadImage(base64, name)
|
const currRender = renders[renders.length - 1]
|
||||||
|
downloadImage(currRender.currentSrc, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
const onSizeLimitChange = (_sizeLimit: string) => {
|
||||||
|
// TODO: clean renders
|
||||||
|
// if (renders.length !== 0) {
|
||||||
|
// }
|
||||||
|
setSizeLimit(_sizeLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -337,7 +344,7 @@ export default function Editor(props: EditorProps) {
|
|||||||
|
|
||||||
<div
|
<div
|
||||||
className={[
|
className={[
|
||||||
'flex items-center w-full max-w-3xl',
|
'flex items-center w-full max-w-5xl',
|
||||||
'space-x-3 sm:space-x-5',
|
'space-x-3 sm:space-x-5',
|
||||||
'p-6',
|
'p-6',
|
||||||
scale !== 1
|
scale !== 1
|
||||||
@ -345,10 +352,15 @@ export default function Editor(props: EditorProps) {
|
|||||||
: 'relative justify-evenly sm:justify-between',
|
: 'relative justify-evenly sm:justify-between',
|
||||||
].join(' ')}
|
].join(' ')}
|
||||||
>
|
>
|
||||||
|
<SizeSelector
|
||||||
|
value={sizeLimit}
|
||||||
|
onChange={onSizeLimitChange}
|
||||||
|
originalSize={`${original.naturalWidth}x${original.naturalHeight}`}
|
||||||
|
/>
|
||||||
<Slider
|
<Slider
|
||||||
label={
|
label={
|
||||||
<span>
|
<span>
|
||||||
<span className="hidden md:inline">Brush</span> Size
|
<span className="hidden md:inline">Brush</span>
|
||||||
</span>
|
</span>
|
||||||
}
|
}
|
||||||
min={10}
|
min={10}
|
||||||
|
@ -2,12 +2,23 @@ import { dataURItoBlob } from '../utils'
|
|||||||
|
|
||||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
|
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
|
||||||
|
|
||||||
export default async function inpaint(imageFile: File, maskBase64: string) {
|
export default async function inpaint(
|
||||||
|
imageFile: File,
|
||||||
|
maskBase64: string,
|
||||||
|
sizeLimit?: string
|
||||||
|
) {
|
||||||
|
// 1080, 2000, Original
|
||||||
const fd = new FormData()
|
const fd = new FormData()
|
||||||
fd.append('image', imageFile)
|
fd.append('image', imageFile)
|
||||||
const mask = dataURItoBlob(maskBase64)
|
const mask = dataURItoBlob(maskBase64)
|
||||||
fd.append('mask', mask)
|
fd.append('mask', mask)
|
||||||
|
|
||||||
|
if (sizeLimit === undefined) {
|
||||||
|
fd.append('sizeLimit', '1080')
|
||||||
|
} else {
|
||||||
|
fd.append('sizeLimit', sizeLimit)
|
||||||
|
}
|
||||||
|
|
||||||
const res = await fetch(API_ENDPOINT, {
|
const res = await fetch(API_ENDPOINT, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: fd,
|
body: fd,
|
||||||
|
File diff suppressed because one or more lines are too long
56
lama_cleaner/app/src/components/SizeSelector.tsx
Normal file
56
lama_cleaner/app/src/components/SizeSelector.tsx
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { RadioGroup } from '@headlessui/react'
|
||||||
|
|
||||||
|
const sizes = [
|
||||||
|
['1080', '1080'],
|
||||||
|
['2000', '2k'],
|
||||||
|
['Original', 'Original'],
|
||||||
|
]
|
||||||
|
|
||||||
|
type SizeSelectorProps = {
|
||||||
|
value?: string
|
||||||
|
originalSize: string
|
||||||
|
onChange: (value: string) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function SizeSelector(props: SizeSelectorProps) {
|
||||||
|
const { value, originalSize, onChange } = props
|
||||||
|
|
||||||
|
return (
|
||||||
|
<RadioGroup
|
||||||
|
className="my-4 flex items-center space-x-2"
|
||||||
|
value={value}
|
||||||
|
onChange={onChange}
|
||||||
|
>
|
||||||
|
<RadioGroup.Label>Resize</RadioGroup.Label>
|
||||||
|
{sizes.map(size => (
|
||||||
|
<RadioGroup.Option key={size[0]} value={size[0]}>
|
||||||
|
{({ checked }) => (
|
||||||
|
<div
|
||||||
|
className={[
|
||||||
|
checked ? 'bg-gray-200' : 'border-opacity-10',
|
||||||
|
'border-3 px-2 py-2 rounded-md',
|
||||||
|
].join(' ')}
|
||||||
|
>
|
||||||
|
<div className="flex items-center space-x-4">
|
||||||
|
<div
|
||||||
|
className={[
|
||||||
|
'rounded-full w-5 h-5 border-4 ',
|
||||||
|
checked
|
||||||
|
? 'border-primary bg-black'
|
||||||
|
: 'border-black border-opacity-10',
|
||||||
|
].join(' ')}
|
||||||
|
/>
|
||||||
|
{size[0] === 'Original' ? (
|
||||||
|
<span>{`${size[1]}(${originalSize})`}</span>
|
||||||
|
) : (
|
||||||
|
<span>{size[1]}</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</RadioGroup.Option>
|
||||||
|
))}
|
||||||
|
</RadioGroup>
|
||||||
|
)
|
||||||
|
}
|
@ -1271,6 +1271,11 @@
|
|||||||
dependencies:
|
dependencies:
|
||||||
"@hapi/hoek" "^8.3.0"
|
"@hapi/hoek" "^8.3.0"
|
||||||
|
|
||||||
|
"@headlessui/react@^1.4.2":
|
||||||
|
version "1.4.2"
|
||||||
|
resolved "https://registry.npmmirror.com/@headlessui/react/download/@headlessui/react-1.4.2.tgz#87e264f190dbebbf8dfdd900530da973dad24576"
|
||||||
|
integrity sha512-N8tv7kLhg9qGKBkVdtg572BvKvWhmiudmeEpOCyNwzOsZHCXBtl8AazGikIfUS+vBoub20Fse3BjawXDVPPdug==
|
||||||
|
|
||||||
"@heroicons/react@^1.0.4":
|
"@heroicons/react@^1.0.4":
|
||||||
version "1.0.4"
|
version "1.0.4"
|
||||||
resolved "https://registry.yarnpkg.com/@heroicons/react/-/react-1.0.4.tgz#11847eb2ea5510419d7ada9ff150a33af0ad0863"
|
resolved "https://registry.yarnpkg.com/@heroicons/react/-/react-1.0.4.tgz#11847eb2ea5510419d7ada9ff150a33af0ad0863"
|
||||||
|
@ -40,21 +40,42 @@ def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
|
|||||||
return image_bytes
|
return image_bytes
|
||||||
|
|
||||||
|
|
||||||
def load_img(img_bytes, gray: bool = False, norm: bool = True):
|
def load_img(img_bytes, gray: bool = False):
|
||||||
nparr = np.frombuffer(img_bytes, np.uint8)
|
nparr = np.frombuffer(img_bytes, np.uint8)
|
||||||
if gray:
|
if gray:
|
||||||
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)[:, :, np.newaxis]
|
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
||||||
else:
|
else:
|
||||||
np_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
||||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
if len(np_img.shape) == 3 and np_img.shape[2] == 4:
|
||||||
|
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
|
||||||
if norm:
|
else:
|
||||||
np_img = np.transpose(np_img, (2, 0, 1))
|
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
||||||
np_img = np_img.astype("float32") / 255
|
|
||||||
|
|
||||||
return np_img
|
return np_img
|
||||||
|
|
||||||
|
|
||||||
|
def norm_img(np_img):
|
||||||
|
if len(np_img.shape) == 2:
|
||||||
|
np_img = np_img[:, :, np.newaxis]
|
||||||
|
np_img = np.transpose(np_img, (2, 0, 1))
|
||||||
|
np_img = np_img.astype("float32") / 255
|
||||||
|
return np_img
|
||||||
|
|
||||||
|
|
||||||
|
def resize_max_size(
|
||||||
|
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
|
||||||
|
) -> np.ndarray:
|
||||||
|
# Resize image's longer size to size_limit if longer size larger than size_limit
|
||||||
|
h, w = np_img.shape[:2]
|
||||||
|
if max(h, w) > size_limit:
|
||||||
|
ratio = size_limit / max(h, w)
|
||||||
|
new_w = int(w * ratio + 0.5)
|
||||||
|
new_h = int(h * ratio + 0.5)
|
||||||
|
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
|
||||||
|
else:
|
||||||
|
return np_img
|
||||||
|
|
||||||
|
|
||||||
def pad_img_to_modulo(img, mod):
|
def pad_img_to_modulo(img, mod):
|
||||||
channels, height, width = img.shape
|
channels, height, width = img.shape
|
||||||
out_height = ceil_modulo(height, mod)
|
out_height = ceil_modulo(height, mod)
|
||||||
|
44
main.py
44
main.py
@ -4,6 +4,8 @@ import io
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
|
from distutils.util import strtobool
|
||||||
|
from typing import Union
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -13,6 +15,8 @@ from flask_cors import CORS
|
|||||||
from lama_cleaner.helper import (
|
from lama_cleaner.helper import (
|
||||||
download_model,
|
download_model,
|
||||||
load_img,
|
load_img,
|
||||||
|
norm_img,
|
||||||
|
resize_max_size,
|
||||||
numpy_to_bytes,
|
numpy_to_bytes,
|
||||||
pad_img_to_modulo,
|
pad_img_to_modulo,
|
||||||
)
|
)
|
||||||
@ -43,13 +47,38 @@ device = None
|
|||||||
def process():
|
def process():
|
||||||
input = request.files
|
input = request.files
|
||||||
image = load_img(input["image"].read())
|
image = load_img(input["image"].read())
|
||||||
|
original_shape = image.shape
|
||||||
|
interpolation = cv2.INTER_CUBIC
|
||||||
|
|
||||||
|
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
|
||||||
|
if size_limit == "Original":
|
||||||
|
size_limit = max(image.shape)
|
||||||
|
else:
|
||||||
|
size_limit = int(size_limit)
|
||||||
|
|
||||||
|
print(f"Origin image shape: {original_shape}")
|
||||||
|
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||||
|
print(f"Resized image shape: {image.shape}")
|
||||||
|
image = norm_img(image)
|
||||||
|
|
||||||
mask = load_img(input["mask"].read(), gray=True)
|
mask = load_img(input["mask"].read(), gray=True)
|
||||||
|
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||||
|
mask = norm_img(mask)
|
||||||
|
|
||||||
res_np_img = run(image, mask)
|
res_np_img = run(image, mask)
|
||||||
|
|
||||||
|
# resize to original size
|
||||||
|
res_np_img = cv2.resize(
|
||||||
|
res_np_img,
|
||||||
|
dsize=(original_shape[1], original_shape[0]),
|
||||||
|
interpolation=interpolation,
|
||||||
|
)
|
||||||
|
|
||||||
return send_file(
|
return send_file(
|
||||||
io.BytesIO(numpy_to_bytes(res_np_img)),
|
io.BytesIO(numpy_to_bytes(res_np_img)),
|
||||||
mimetype="image/png",
|
mimetype="image/jpeg",
|
||||||
as_attachment=True,
|
as_attachment=True,
|
||||||
attachment_filename="result.png",
|
attachment_filename="result.jpeg",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -61,6 +90,8 @@ def index():
|
|||||||
def run(image, mask):
|
def run(image, mask):
|
||||||
"""
|
"""
|
||||||
image: [C, H, W]
|
image: [C, H, W]
|
||||||
|
mask: [1, H, W]
|
||||||
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
origin_height, origin_width = image.shape[1:]
|
origin_height, origin_width = image.shape[1:]
|
||||||
image = pad_img_to_modulo(image, mod=8)
|
image = pad_img_to_modulo(image, mod=8)
|
||||||
@ -73,13 +104,11 @@ def run(image, mask):
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
inpainted_image = model(image, mask)
|
inpainted_image = model(image, mask)
|
||||||
|
|
||||||
print(
|
print(f"process time: {(time.time() - start)*1000}ms")
|
||||||
f"inpainted image shape: {inpainted_image.shape} process time: {(time.time() - start)*1000}ms"
|
|
||||||
)
|
|
||||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||||
cur_res = cur_res[0:origin_height, 0:origin_width, :]
|
cur_res = cur_res[0:origin_height, 0:origin_width, :]
|
||||||
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
||||||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
|
||||||
return cur_res
|
return cur_res
|
||||||
|
|
||||||
|
|
||||||
@ -87,6 +116,7 @@ def get_args_parser():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--port", default=8080, type=int)
|
parser.add_argument("--port", default=8080, type=int)
|
||||||
parser.add_argument("--device", default="cuda", type=str)
|
parser.add_argument("--device", default="cuda", type=str)
|
||||||
|
parser.add_argument("--debug", action="store_true")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -98,7 +128,7 @@ def main():
|
|||||||
model_path = download_model()
|
model_path = download_model()
|
||||||
model = torch.jit.load(model_path, map_location="cpu")
|
model = torch.jit.load(model_path, map_location="cpu")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
app.run(host="0.0.0.0", port=args.port, debug=False)
|
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user