return correct file ext/mimetype
This commit is contained in:
parent
98fa52ba08
commit
caed45b520
@ -30,8 +30,8 @@ def ceil_modulo(x, mod):
|
|||||||
return (x // mod + 1) * mod
|
return (x // mod + 1) * mod
|
||||||
|
|
||||||
|
|
||||||
def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
|
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
||||||
data = cv2.imencode(".jpg", image_numpy)[1]
|
data = cv2.imencode(f".{ext}", image_numpy)[1]
|
||||||
image_bytes = data.tobytes()
|
image_bytes = data.tobytes()
|
||||||
return image_bytes
|
return image_bytes
|
||||||
|
|
||||||
@ -92,7 +92,9 @@ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
height, width = mask.shape[1:]
|
height, width = mask.shape[1:]
|
||||||
_, thresh = cv2.threshold((mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0)
|
_, thresh = cv2.threshold(
|
||||||
|
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
|
||||||
|
)
|
||||||
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
boxes = []
|
boxes = []
|
||||||
|
83
main.py
83
main.py
@ -44,8 +44,7 @@ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
|||||||
if os.environ.get("CACHE_DIR"):
|
if os.environ.get("CACHE_DIR"):
|
||||||
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
||||||
|
|
||||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR",
|
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
|
||||||
"./lama_cleaner/app/build")
|
|
||||||
|
|
||||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||||
app.config["JSON_AS_ASCII"] = False
|
app.config["JSON_AS_ASCII"] = False
|
||||||
@ -56,11 +55,20 @@ device = None
|
|||||||
input_image_path: str = None
|
input_image_path: str = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_ext(img_bytes):
|
||||||
|
w = imghdr.what("", img_bytes)
|
||||||
|
if w is None:
|
||||||
|
w = "jpeg"
|
||||||
|
return w
|
||||||
|
|
||||||
|
|
||||||
@app.route("/inpaint", methods=["POST"])
|
@app.route("/inpaint", methods=["POST"])
|
||||||
def process():
|
def process():
|
||||||
input = request.files
|
input = request.files
|
||||||
# RGB
|
# RGB
|
||||||
image = load_img(input["image"].read())
|
origin_image_bytes = input["image"].read()
|
||||||
|
|
||||||
|
image = load_img(origin_image_bytes)
|
||||||
original_shape = image.shape
|
original_shape = image.shape
|
||||||
interpolation = cv2.INTER_CUBIC
|
interpolation = cv2.INTER_CUBIC
|
||||||
|
|
||||||
@ -71,14 +79,12 @@ def process():
|
|||||||
size_limit = int(size_limit)
|
size_limit = int(size_limit)
|
||||||
|
|
||||||
print(f"Origin image shape: {original_shape}")
|
print(f"Origin image shape: {original_shape}")
|
||||||
image = resize_max_size(image, size_limit=size_limit,
|
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||||
interpolation=interpolation)
|
|
||||||
print(f"Resized image shape: {image.shape}")
|
print(f"Resized image shape: {image.shape}")
|
||||||
image = norm_img(image)
|
image = norm_img(image)
|
||||||
|
|
||||||
mask = load_img(input["mask"].read(), gray=True)
|
mask = load_img(input["mask"].read(), gray=True)
|
||||||
mask = resize_max_size(mask, size_limit=size_limit,
|
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||||
interpolation=interpolation)
|
|
||||||
mask = norm_img(mask)
|
mask = norm_img(mask)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -87,11 +93,10 @@ def process():
|
|||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
ext = get_image_ext(origin_image_bytes)
|
||||||
return send_file(
|
return send_file(
|
||||||
io.BytesIO(numpy_to_bytes(res_np_img)),
|
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||||
mimetype="image/jpeg",
|
mimetype=f"image/{ext}",
|
||||||
as_attachment=True,
|
|
||||||
attachment_filename="result.jpeg",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -100,29 +105,42 @@ def index():
|
|||||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/inputimage')
|
@app.route("/inputimage")
|
||||||
def set_input_photo():
|
def set_input_photo():
|
||||||
if input_image_path:
|
if input_image_path:
|
||||||
with open(input_image_path, 'rb') as f:
|
with open(input_image_path, "rb") as f:
|
||||||
image_in_bytes = f.read()
|
image_in_bytes = f.read()
|
||||||
return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg')
|
return send_file(
|
||||||
|
io.BytesIO(image_in_bytes),
|
||||||
|
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return 'No Input Image'
|
return "No Input Image"
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser():
|
def get_args_parser():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input", type=str, help="Path to image you want to load by default")
|
"--input", type=str, help="Path to image you want to load by default"
|
||||||
|
)
|
||||||
parser.add_argument("--port", default=8080, type=int)
|
parser.add_argument("--port", default=8080, type=int)
|
||||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||||
parser.add_argument("--crop-trigger-size", default=[2042, 2042], nargs=2, type=int,
|
parser.add_argument(
|
||||||
help="If image size large then crop-trigger-size, "
|
"--crop-trigger-size",
|
||||||
"crop each area from original image to do inference."
|
default=[2042, 2042],
|
||||||
"Mainly for performance and memory reasons"
|
nargs=2,
|
||||||
"Only for lama")
|
type=int,
|
||||||
parser.add_argument("--crop-margin", type=int, default=256,
|
help="If image size large then crop-trigger-size, "
|
||||||
help="Margin around bounding box of painted stroke when crop mode triggered")
|
"crop each area from original image to do inference."
|
||||||
|
"Mainly for performance and memory reasons"
|
||||||
|
"Only for lama",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--crop-margin",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Margin around bounding box of painted stroke when crop mode triggered",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ldm-steps",
|
"--ldm-steps",
|
||||||
default=50,
|
default=50,
|
||||||
@ -131,10 +149,14 @@ def get_args_parser():
|
|||||||
"The larger the value, the better the result, but it will be more time-consuming",
|
"The larger the value, the better the result, but it will be more time-consuming",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", default="cuda", type=str)
|
parser.add_argument("--device", default="cuda", type=str)
|
||||||
parser.add_argument("--gui", action="store_true",
|
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||||
help="Launch as desktop app")
|
parser.add_argument(
|
||||||
parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int,
|
"--gui-size",
|
||||||
help="Set window size for GUI")
|
default=[1600, 1000],
|
||||||
|
nargs=2,
|
||||||
|
type=int,
|
||||||
|
help="Set window size for GUI",
|
||||||
|
)
|
||||||
parser.add_argument("--debug", action="store_true")
|
parser.add_argument("--debug", action="store_true")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -157,8 +179,11 @@ def main():
|
|||||||
input_image_path = args.input
|
input_image_path = args.input
|
||||||
|
|
||||||
if args.model == "lama":
|
if args.model == "lama":
|
||||||
model = LaMa(crop_trigger_size=args.crop_trigger_size,
|
model = LaMa(
|
||||||
crop_margin=args.crop_margin, device=device)
|
crop_trigger_size=args.crop_trigger_size,
|
||||||
|
crop_margin=args.crop_margin,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
elif args.model == "ldm":
|
elif args.model == "ldm":
|
||||||
model = LDM(device, steps=args.ldm_steps)
|
model = LDM(device, steps=args.ldm_steps)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user