()
const [showShortcuts, toggleShowShortcuts] = useToggle(false)
const windowSize = useWindowSize()
+ const userInputImage = useInputImage()
+
+ useEffect(() => {
+ setFile(userInputImage)
+ }, [userInputImage])
return (
diff --git a/lama_cleaner/app/src/components/hooks/useInputImage.js b/lama_cleaner/app/src/components/hooks/useInputImage.js
new file mode 100644
index 0000000..0c566f1
--- /dev/null
+++ b/lama_cleaner/app/src/components/hooks/useInputImage.js
@@ -0,0 +1,22 @@
+import { useCallback, useEffect, useState } from 'react'
+
+export default function useInputImage() {
+ const [inputImage, setInputImage] = useState()
+
+ const fetchInputImage = useCallback(() => {
+ fetch('/inputimage')
+ .then(res => res.blob())
+ .then(data => {
+ if (data && data.type.startsWith('image')) {
+ const userInput = new File([data], 'inputImage')
+ setInputImage(userInput)
+ }
+ })
+ }, [setInputImage])
+
+ useEffect(() => {
+ fetchInputImage()
+ }, [])
+
+ return inputImage
+}
diff --git a/main.py b/main.py
index 7424b02..fc79eff 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,7 @@ import io
import multiprocessing
import os
import time
+import imghdr
from typing import Union
import cv2
@@ -98,8 +99,24 @@ def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
+@app.route('/inputimage')
+def set_input_photo():
+ filename = os.path.join(os.path.dirname(__file__), input_image)
+ if (os.path.exists(filename)):
+ if (imghdr.what(filename) is not None):
+ with open(filename, 'rb') as f:
+ byte_im = f.read()
+ return send_file(io.BytesIO(byte_im), mimetype='image/jpeg')
+ else:
+ return 'Invalid Input'
+ else:
+ return 'Invalid Input'
+
+
def get_args_parser():
parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input", default='', type=str, help="Path to image you want to load by default")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument("--crop-trigger-size", nargs=2, type=int,
@@ -128,11 +145,16 @@ def get_args_parser():
def main():
global model
global device
+ global input_image
+
args = get_args_parser()
device = torch.device(args.device)
+ input_image = args.input
+
if args.model == "lama":
- model = LaMa(crop_trigger_size=args.crop_trigger_size, crop_margin=args.crop_margin, device=device)
+ model = LaMa(crop_trigger_size=args.crop_trigger_size,
+ crop_margin=args.crop_margin, device=device)
elif args.model == "ldm":
model = LDM(device, steps=args.ldm_steps)
else: