resize image using backend;add resize radio button

frontend resize image will reduce image quality
This commit is contained in:
Qing 2021-11-27 20:37:37 +08:00 committed by Sanster
parent 1c2e7fa559
commit 1e2c8fd348
9 changed files with 163 additions and 144 deletions

View File

@ -4,6 +4,7 @@
"private": true, "private": true,
"proxy": "http://localhost:8080", "proxy": "http://localhost:8080",
"dependencies": { "dependencies": {
"@headlessui/react": "^1.4.2",
"@heroicons/react": "^1.0.4", "@heroicons/react": "^1.0.4",
"@testing-library/jest-dom": "^5.14.1", "@testing-library/jest-dom": "^5.14.1",
"@testing-library/react": "^12.1.2", "@testing-library/react": "^12.1.2",

View File

@ -4,7 +4,6 @@ import { useWindowSize } from 'react-use'
import Button from './components/Button' import Button from './components/Button'
import FileSelect from './components/FileSelect' import FileSelect from './components/FileSelect'
import Editor from './Editor' import Editor from './Editor'
import { resizeImageFile } from './utils'
function App() { function App() {
const [file, setFile] = useState<File>() const [file, setFile] = useState<File>()

View File

@ -1,10 +1,11 @@
import { DownloadIcon, EyeIcon } from '@heroicons/react/outline' import { DownloadIcon, EyeIcon } from '@heroicons/react/outline'
import React, { useCallback, useEffect, useState } from 'react' import React, { useCallback, useEffect, useState } from 'react'
import { useWindowSize } from 'react-use' import { useWindowSize, useLocalStorage } from 'react-use'
import inpaint from './adapters/inpainting' import inpaint from './adapters/inpainting'
import Button from './components/Button' import Button from './components/Button'
import Slider from './components/Slider' import Slider from './components/Slider'
import { downloadImage, loadImage, shareImage, useImage } from './utils' import SizeSelector from './components/SizeSelector'
import { downloadImage, loadImage, useImage } from './utils'
const TOOLBAR_SIZE = 200 const TOOLBAR_SIZE = 200
const BRUSH_COLOR = 'rgba(189, 255, 1, 0.75)' const BRUSH_COLOR = 'rgba(189, 255, 1, 0.75)'
@ -55,6 +56,8 @@ export default function Editor(props: EditorProps) {
const [isInpaintingLoading, setIsInpaintingLoading] = useState(false) const [isInpaintingLoading, setIsInpaintingLoading] = useState(false)
const [showSeparator, setShowSeparator] = useState(false) const [showSeparator, setShowSeparator] = useState(false)
const [scale, setScale] = useState(1) const [scale, setScale] = useState(1)
// ['1080', '2000', 'Original']
const [sizeLimit, setSizeLimit] = useLocalStorage('sizeLimit', '1080')
const windowSize = useWindowSize() const windowSize = useWindowSize()
const draw = useCallback(() => { const draw = useCallback(() => {
@ -144,8 +147,7 @@ export default function Editor(props: EditorProps) {
window.removeEventListener('mouseup', onPointerUp) window.removeEventListener('mouseup', onPointerUp)
refreshCanvasMask() refreshCanvasMask()
try { try {
const start = Date.now() const res = await inpaint(file, maskCanvas.toDataURL(), sizeLimit)
const res = await inpaint(file, maskCanvas.toDataURL())
if (!res) { if (!res) {
throw new Error('empty response') throw new Error('empty response')
} }
@ -221,6 +223,7 @@ export default function Editor(props: EditorProps) {
original.naturalHeight, original.naturalHeight,
original.naturalWidth, original.naturalWidth,
scale, scale,
sizeLimit,
]) ])
const undo = useCallback(() => { const undo = useCallback(() => {
@ -252,12 +255,16 @@ export default function Editor(props: EditorProps) {
}, [renders, undo]) }, [renders, undo])
function download() { function download() {
const base64 = context?.canvas.toDataURL(file.type)
if (!base64) {
throw new Error('could not get canvas data')
}
const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1') const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1')
downloadImage(base64, name) const currRender = renders[renders.length - 1]
downloadImage(currRender.currentSrc, name)
}
const onSizeLimitChange = (_sizeLimit: string) => {
// TODO: clean renders
// if (renders.length !== 0) {
// }
setSizeLimit(_sizeLimit)
} }
return ( return (
@ -337,7 +344,7 @@ export default function Editor(props: EditorProps) {
<div <div
className={[ className={[
'flex items-center w-full max-w-3xl', 'flex items-center w-full max-w-5xl',
'space-x-3 sm:space-x-5', 'space-x-3 sm:space-x-5',
'p-6', 'p-6',
scale !== 1 scale !== 1
@ -345,10 +352,15 @@ export default function Editor(props: EditorProps) {
: 'relative justify-evenly sm:justify-between', : 'relative justify-evenly sm:justify-between',
].join(' ')} ].join(' ')}
> >
<SizeSelector
value={sizeLimit}
onChange={onSizeLimitChange}
originalSize={`${original.naturalWidth}x${original.naturalHeight}`}
/>
<Slider <Slider
label={ label={
<span> <span>
<span className="hidden md:inline">Brush</span> Size <span className="hidden md:inline">Brush</span>
</span> </span>
} }
min={10} min={10}

View File

@ -2,12 +2,23 @@ import { dataURItoBlob } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint` export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
export default async function inpaint(imageFile: File, maskBase64: string) { export default async function inpaint(
imageFile: File,
maskBase64: string,
sizeLimit?: string
) {
// 1080, 2000, Original
const fd = new FormData() const fd = new FormData()
fd.append('image', imageFile) fd.append('image', imageFile)
const mask = dataURItoBlob(maskBase64) const mask = dataURItoBlob(maskBase64)
fd.append('mask', mask) fd.append('mask', mask)
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {
fd.append('sizeLimit', sizeLimit)
}
const res = await fetch(API_ENDPOINT, { const res = await fetch(API_ENDPOINT, {
method: 'POST', method: 'POST',
body: fd, body: fd,

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,56 @@
import React from 'react'
import { RadioGroup } from '@headlessui/react'
const sizes = [
['1080', '1080'],
['2000', '2k'],
['Original', 'Original'],
]
type SizeSelectorProps = {
value?: string
originalSize: string
onChange: (value: string) => void
}
export default function SizeSelector(props: SizeSelectorProps) {
const { value, originalSize, onChange } = props
return (
<RadioGroup
className="my-4 flex items-center space-x-2"
value={value}
onChange={onChange}
>
<RadioGroup.Label>Resize</RadioGroup.Label>
{sizes.map(size => (
<RadioGroup.Option key={size[0]} value={size[0]}>
{({ checked }) => (
<div
className={[
checked ? 'bg-gray-200' : 'border-opacity-10',
'border-3 px-2 py-2 rounded-md',
].join(' ')}
>
<div className="flex items-center space-x-4">
<div
className={[
'rounded-full w-5 h-5 border-4 ',
checked
? 'border-primary bg-black'
: 'border-black border-opacity-10',
].join(' ')}
/>
{size[0] === 'Original' ? (
<span>{`${size[1]}(${originalSize})`}</span>
) : (
<span>{size[1]}</span>
)}
</div>
</div>
)}
</RadioGroup.Option>
))}
</RadioGroup>
)
}

View File

@ -1271,6 +1271,11 @@
dependencies: dependencies:
"@hapi/hoek" "^8.3.0" "@hapi/hoek" "^8.3.0"
"@headlessui/react@^1.4.2":
version "1.4.2"
resolved "https://registry.npmmirror.com/@headlessui/react/download/@headlessui/react-1.4.2.tgz#87e264f190dbebbf8dfdd900530da973dad24576"
integrity sha512-N8tv7kLhg9qGKBkVdtg572BvKvWhmiudmeEpOCyNwzOsZHCXBtl8AazGikIfUS+vBoub20Fse3BjawXDVPPdug==
"@heroicons/react@^1.0.4": "@heroicons/react@^1.0.4":
version "1.0.4" version "1.0.4"
resolved "https://registry.yarnpkg.com/@heroicons/react/-/react-1.0.4.tgz#11847eb2ea5510419d7ada9ff150a33af0ad0863" resolved "https://registry.yarnpkg.com/@heroicons/react/-/react-1.0.4.tgz#11847eb2ea5510419d7ada9ff150a33af0ad0863"

View File

@ -40,21 +40,42 @@ def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
return image_bytes return image_bytes
def load_img(img_bytes, gray: bool = False, norm: bool = True): def load_img(img_bytes, gray: bool = False):
nparr = np.frombuffer(img_bytes, np.uint8) nparr = np.frombuffer(img_bytes, np.uint8)
if gray: if gray:
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)[:, :, np.newaxis] np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
else: else:
np_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) if len(np_img.shape) == 3 and np_img.shape[2] == 4:
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
if norm: else:
np_img = np.transpose(np_img, (2, 0, 1)) np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
np_img = np_img.astype("float32") / 255
return np_img return np_img
def norm_img(np_img):
if len(np_img.shape) == 2:
np_img = np_img[:, :, np.newaxis]
np_img = np.transpose(np_img, (2, 0, 1))
np_img = np_img.astype("float32") / 255
return np_img
def resize_max_size(
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
# Resize image's longer size to size_limit if longer size larger than size_limit
h, w = np_img.shape[:2]
if max(h, w) > size_limit:
ratio = size_limit / max(h, w)
new_w = int(w * ratio + 0.5)
new_h = int(h * ratio + 0.5)
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
else:
return np_img
def pad_img_to_modulo(img, mod): def pad_img_to_modulo(img, mod):
channels, height, width = img.shape channels, height, width = img.shape
out_height = ceil_modulo(height, mod) out_height = ceil_modulo(height, mod)

44
main.py
View File

@ -4,6 +4,8 @@ import io
import os import os
import time import time
import argparse import argparse
from distutils.util import strtobool
from typing import Union
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
@ -13,6 +15,8 @@ from flask_cors import CORS
from lama_cleaner.helper import ( from lama_cleaner.helper import (
download_model, download_model,
load_img, load_img,
norm_img,
resize_max_size,
numpy_to_bytes, numpy_to_bytes,
pad_img_to_modulo, pad_img_to_modulo,
) )
@ -43,13 +47,38 @@ device = None
def process(): def process():
input = request.files input = request.files
image = load_img(input["image"].read()) image = load_img(input["image"].read())
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
print(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
print(f"Resized image shape: {image.shape}")
image = norm_img(image)
mask = load_img(input["mask"].read(), gray=True) mask = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask)
res_np_img = run(image, mask) res_np_img = run(image, mask)
# resize to original size
res_np_img = cv2.resize(
res_np_img,
dsize=(original_shape[1], original_shape[0]),
interpolation=interpolation,
)
return send_file( return send_file(
io.BytesIO(numpy_to_bytes(res_np_img)), io.BytesIO(numpy_to_bytes(res_np_img)),
mimetype="image/png", mimetype="image/jpeg",
as_attachment=True, as_attachment=True,
attachment_filename="result.png", attachment_filename="result.jpeg",
) )
@ -61,6 +90,8 @@ def index():
def run(image, mask): def run(image, mask):
""" """
image: [C, H, W] image: [C, H, W]
mask: [1, H, W]
return: BGR IMAGE
""" """
origin_height, origin_width = image.shape[1:] origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=8) image = pad_img_to_modulo(image, mod=8)
@ -73,13 +104,11 @@ def run(image, mask):
start = time.time() start = time.time()
inpainted_image = model(image, mask) inpainted_image = model(image, mask)
print( print(f"process time: {(time.time() - start)*1000}ms")
f"inpainted image shape: {inpainted_image.shape} process time: {(time.time() - start)*1000}ms"
)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = cur_res[0:origin_height, 0:origin_width, :] cur_res = cur_res[0:origin_height, 0:origin_width, :]
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
return cur_res return cur_res
@ -87,6 +116,7 @@ def get_args_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--port", default=8080, type=int) parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--device", default="cuda", type=str) parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--debug", action="store_true")
return parser.parse_args() return parser.parse_args()
@ -98,7 +128,7 @@ def main():
model_path = download_model() model_path = download_model()
model = torch.jit.load(model_path, map_location="cpu") model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device) model = model.to(device)
app.run(host="0.0.0.0", port=args.port, debug=False) app.run(host="0.0.0.0", port=args.port, debug=args.debug)
if __name__ == "__main__": if __name__ == "__main__":