diff --git a/lama_cleaner/app/package.json b/lama_cleaner/app/package.json index 690c164..bd80a21 100644 --- a/lama_cleaner/app/package.json +++ b/lama_cleaner/app/package.json @@ -4,6 +4,7 @@ "private": true, "proxy": "http://localhost:8080", "dependencies": { + "@headlessui/react": "^1.4.2", "@heroicons/react": "^1.0.4", "@testing-library/jest-dom": "^5.14.1", "@testing-library/react": "^12.1.2", diff --git a/lama_cleaner/app/src/App.tsx b/lama_cleaner/app/src/App.tsx index 27a118a..7fd03c1 100644 --- a/lama_cleaner/app/src/App.tsx +++ b/lama_cleaner/app/src/App.tsx @@ -4,7 +4,6 @@ import { useWindowSize } from 'react-use' import Button from './components/Button' import FileSelect from './components/FileSelect' import Editor from './Editor' -import { resizeImageFile } from './utils' function App() { const [file, setFile] = useState() diff --git a/lama_cleaner/app/src/Editor.tsx b/lama_cleaner/app/src/Editor.tsx index 949527f..e22a09c 100644 --- a/lama_cleaner/app/src/Editor.tsx +++ b/lama_cleaner/app/src/Editor.tsx @@ -1,10 +1,11 @@ import { DownloadIcon, EyeIcon } from '@heroicons/react/outline' import React, { useCallback, useEffect, useState } from 'react' -import { useWindowSize } from 'react-use' +import { useWindowSize, useLocalStorage } from 'react-use' import inpaint from './adapters/inpainting' import Button from './components/Button' import Slider from './components/Slider' -import { downloadImage, loadImage, shareImage, useImage } from './utils' +import SizeSelector from './components/SizeSelector' +import { downloadImage, loadImage, useImage } from './utils' const TOOLBAR_SIZE = 200 const BRUSH_COLOR = 'rgba(189, 255, 1, 0.75)' @@ -55,6 +56,8 @@ export default function Editor(props: EditorProps) { const [isInpaintingLoading, setIsInpaintingLoading] = useState(false) const [showSeparator, setShowSeparator] = useState(false) const [scale, setScale] = useState(1) + // ['1080', '2000', 'Original'] + const [sizeLimit, setSizeLimit] = useLocalStorage('sizeLimit', '1080') const windowSize = useWindowSize() const draw = useCallback(() => { @@ -144,8 +147,7 @@ export default function Editor(props: EditorProps) { window.removeEventListener('mouseup', onPointerUp) refreshCanvasMask() try { - const start = Date.now() - const res = await inpaint(file, maskCanvas.toDataURL()) + const res = await inpaint(file, maskCanvas.toDataURL(), sizeLimit) if (!res) { throw new Error('empty response') } @@ -221,6 +223,7 @@ export default function Editor(props: EditorProps) { original.naturalHeight, original.naturalWidth, scale, + sizeLimit, ]) const undo = useCallback(() => { @@ -252,12 +255,16 @@ export default function Editor(props: EditorProps) { }, [renders, undo]) function download() { - const base64 = context?.canvas.toDataURL(file.type) - if (!base64) { - throw new Error('could not get canvas data') - } const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1') - downloadImage(base64, name) + const currRender = renders[renders.length - 1] + downloadImage(currRender.currentSrc, name) + } + + const onSizeLimitChange = (_sizeLimit: string) => { + // TODO: clean renders + // if (renders.length !== 0) { + // } + setSizeLimit(_sizeLimit) } return ( @@ -337,7 +344,7 @@ export default function Editor(props: EditorProps) {
+ - Brush Size + Brush } min={10} diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index c9d6353..17c9b6b 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -2,12 +2,23 @@ import { dataURItoBlob } from '../utils' export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint` -export default async function inpaint(imageFile: File, maskBase64: string) { +export default async function inpaint( + imageFile: File, + maskBase64: string, + sizeLimit?: string +) { + // 1080, 2000, Original const fd = new FormData() fd.append('image', imageFile) const mask = dataURItoBlob(maskBase64) fd.append('mask', mask) + if (sizeLimit === undefined) { + fd.append('sizeLimit', '1080') + } else { + fd.append('sizeLimit', sizeLimit) + } + const res = await fetch(API_ENDPOINT, { method: 'POST', body: fd, diff --git a/lama_cleaner/app/src/components/Logo.tsx b/lama_cleaner/app/src/components/Logo.tsx deleted file mode 100644 index bf5699d..0000000 --- a/lama_cleaner/app/src/components/Logo.tsx +++ /dev/null @@ -1,116 +0,0 @@ -import React from 'react' - -interface LogoProps { - className?: string -} - -export default function Logo(props: LogoProps) { - const { className } = props - return ( - - - - - - - - - - - - - - - - - - - - - - - - - - - ) -} diff --git a/lama_cleaner/app/src/components/SizeSelector.tsx b/lama_cleaner/app/src/components/SizeSelector.tsx new file mode 100644 index 0000000..447b72e --- /dev/null +++ b/lama_cleaner/app/src/components/SizeSelector.tsx @@ -0,0 +1,56 @@ +import React from 'react' +import { RadioGroup } from '@headlessui/react' + +const sizes = [ + ['1080', '1080'], + ['2000', '2k'], + ['Original', 'Original'], +] + +type SizeSelectorProps = { + value?: string + originalSize: string + onChange: (value: string) => void +} + +export default function SizeSelector(props: SizeSelectorProps) { + const { value, originalSize, onChange } = props + + return ( + + Resize + {sizes.map(size => ( + + {({ checked }) => ( +
+
+
+ {size[0] === 'Original' ? ( + {`${size[1]}(${originalSize})`} + ) : ( + {size[1]} + )} +
+
+ )} + + ))} + + ) +} diff --git a/lama_cleaner/app/yarn.lock b/lama_cleaner/app/yarn.lock index 4ffb4b4..e2116e0 100644 --- a/lama_cleaner/app/yarn.lock +++ b/lama_cleaner/app/yarn.lock @@ -1271,6 +1271,11 @@ dependencies: "@hapi/hoek" "^8.3.0" +"@headlessui/react@^1.4.2": + version "1.4.2" + resolved "https://registry.npmmirror.com/@headlessui/react/download/@headlessui/react-1.4.2.tgz#87e264f190dbebbf8dfdd900530da973dad24576" + integrity sha512-N8tv7kLhg9qGKBkVdtg572BvKvWhmiudmeEpOCyNwzOsZHCXBtl8AazGikIfUS+vBoub20Fse3BjawXDVPPdug== + "@heroicons/react@^1.0.4": version "1.0.4" resolved "https://registry.yarnpkg.com/@heroicons/react/-/react-1.0.4.tgz#11847eb2ea5510419d7ada9ff150a33af0ad0863" diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 75fd85a..cbf8b92 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -40,21 +40,42 @@ def numpy_to_bytes(image_numpy: np.ndarray) -> bytes: return image_bytes -def load_img(img_bytes, gray: bool = False, norm: bool = True): +def load_img(img_bytes, gray: bool = False): nparr = np.frombuffer(img_bytes, np.uint8) if gray: - np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)[:, :, np.newaxis] + np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) else: - np_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) - np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) - - if norm: - np_img = np.transpose(np_img, (2, 0, 1)) - np_img = np_img.astype("float32") / 255 + np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) + if len(np_img.shape) == 3 and np_img.shape[2] == 4: + np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB) + else: + np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) return np_img +def norm_img(np_img): + if len(np_img.shape) == 2: + np_img = np_img[:, :, np.newaxis] + np_img = np.transpose(np_img, (2, 0, 1)) + np_img = np_img.astype("float32") / 255 + return np_img + + +def resize_max_size( + np_img, size_limit: int, interpolation=cv2.INTER_CUBIC +) -> np.ndarray: + # Resize image's longer size to size_limit if longer size larger than size_limit + h, w = np_img.shape[:2] + if max(h, w) > size_limit: + ratio = size_limit / max(h, w) + new_w = int(w * ratio + 0.5) + new_h = int(h * ratio + 0.5) + return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation) + else: + return np_img + + def pad_img_to_modulo(img, mod): channels, height, width = img.shape out_height = ceil_modulo(height, mod) diff --git a/main.py b/main.py index 1d41ff3..7500240 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,8 @@ import io import os import time import argparse +from distutils.util import strtobool +from typing import Union import cv2 import numpy as np import torch @@ -13,6 +15,8 @@ from flask_cors import CORS from lama_cleaner.helper import ( download_model, load_img, + norm_img, + resize_max_size, numpy_to_bytes, pad_img_to_modulo, ) @@ -43,13 +47,38 @@ device = None def process(): input = request.files image = load_img(input["image"].read()) + original_shape = image.shape + interpolation = cv2.INTER_CUBIC + + size_limit: Union[int, str] = request.form.get("sizeLimit", "1080") + if size_limit == "Original": + size_limit = max(image.shape) + else: + size_limit = int(size_limit) + + print(f"Origin image shape: {original_shape}") + image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) + print(f"Resized image shape: {image.shape}") + image = norm_img(image) + mask = load_img(input["mask"].read(), gray=True) + mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) + mask = norm_img(mask) + res_np_img = run(image, mask) + + # resize to original size + res_np_img = cv2.resize( + res_np_img, + dsize=(original_shape[1], original_shape[0]), + interpolation=interpolation, + ) + return send_file( io.BytesIO(numpy_to_bytes(res_np_img)), - mimetype="image/png", + mimetype="image/jpeg", as_attachment=True, - attachment_filename="result.png", + attachment_filename="result.jpeg", ) @@ -61,6 +90,8 @@ def index(): def run(image, mask): """ image: [C, H, W] + mask: [1, H, W] + return: BGR IMAGE """ origin_height, origin_width = image.shape[1:] image = pad_img_to_modulo(image, mod=8) @@ -73,13 +104,11 @@ def run(image, mask): start = time.time() inpainted_image = model(image, mask) - print( - f"inpainted image shape: {inpainted_image.shape} process time: {(time.time() - start)*1000}ms" - ) + print(f"process time: {(time.time() - start)*1000}ms") cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() cur_res = cur_res[0:origin_height, 0:origin_width, :] cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") - cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB) return cur_res @@ -87,6 +116,7 @@ def get_args_parser(): parser = argparse.ArgumentParser() parser.add_argument("--port", default=8080, type=int) parser.add_argument("--device", default="cuda", type=str) + parser.add_argument("--debug", action="store_true") return parser.parse_args() @@ -98,7 +128,7 @@ def main(): model_path = download_model() model = torch.jit.load(model_path, map_location="cpu") model = model.to(device) - app.run(host="0.0.0.0", port=args.port, debug=False) + app.run(host="0.0.0.0", port=args.port, debug=args.debug) if __name__ == "__main__":