Return the correct file extension and MIME type

This commit is contained in:
Sanster 2022-04-09 07:23:33 +08:00
parent 98fa52ba08
commit caed45b520
2 changed files with 59 additions and 32 deletions

View File

@ -30,8 +30,8 @@ def ceil_modulo(x, mod):
return (x // mod + 1) * mod return (x // mod + 1) * mod
def numpy_to_bytes(image_numpy: np.ndarray) -> bytes: def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
data = cv2.imencode(".jpg", image_numpy)[1] data = cv2.imencode(f".{ext}", image_numpy)[1]
image_bytes = data.tobytes() image_bytes = data.tobytes()
return image_bytes return image_bytes
@ -92,7 +92,9 @@ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
""" """
height, width = mask.shape[1:] height, width = mask.shape[1:]
_, thresh = cv2.threshold((mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0) _, thresh = cv2.threshold(
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = [] boxes = []

83
main.py
View File

@ -44,8 +44,7 @@ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"): if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
"./lama_cleaner/app/build")
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False app.config["JSON_AS_ASCII"] = False
@ -56,11 +55,20 @@ device = None
input_image_path: str = None input_image_path: str = None
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
@app.route("/inpaint", methods=["POST"]) @app.route("/inpaint", methods=["POST"])
def process(): def process():
input = request.files input = request.files
# RGB # RGB
image = load_img(input["image"].read()) origin_image_bytes = input["image"].read()
image = load_img(origin_image_bytes)
original_shape = image.shape original_shape = image.shape
interpolation = cv2.INTER_CUBIC interpolation = cv2.INTER_CUBIC
@ -71,14 +79,12 @@ def process():
size_limit = int(size_limit) size_limit = int(size_limit)
print(f"Origin image shape: {original_shape}") print(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
interpolation=interpolation)
print(f"Resized image shape: {image.shape}") print(f"Resized image shape: {image.shape}")
image = norm_img(image) image = norm_img(image)
mask = load_img(input["mask"].read(), gray=True) mask = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
interpolation=interpolation)
mask = norm_img(mask) mask = norm_img(mask)
start = time.time() start = time.time()
@ -87,11 +93,10 @@ def process():
torch.cuda.empty_cache() torch.cuda.empty_cache()
ext = get_image_ext(origin_image_bytes)
return send_file( return send_file(
io.BytesIO(numpy_to_bytes(res_np_img)), io.BytesIO(numpy_to_bytes(res_np_img, ext)),
mimetype="image/jpeg", mimetype=f"image/{ext}",
as_attachment=True,
attachment_filename="result.jpeg",
) )
@ -100,29 +105,42 @@ def index():
return send_file(os.path.join(BUILD_DIR, "index.html")) return send_file(os.path.join(BUILD_DIR, "index.html"))
@app.route('/inputimage') @app.route("/inputimage")
def set_input_photo(): def set_input_photo():
if input_image_path: if input_image_path:
with open(input_image_path, 'rb') as f: with open(input_image_path, "rb") as f:
image_in_bytes = f.read() image_in_bytes = f.read()
return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg') return send_file(
io.BytesIO(image_in_bytes),
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else: else:
return 'No Input Image' return "No Input Image"
def get_args_parser(): def get_args_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--input", type=str, help="Path to image you want to load by default") "--input", type=str, help="Path to image you want to load by default"
)
parser.add_argument("--port", default=8080, type=int) parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"]) parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument("--crop-trigger-size", default=[2042, 2042], nargs=2, type=int, parser.add_argument(
help="If image size large then crop-trigger-size, " "--crop-trigger-size",
"crop each area from original image to do inference." default=[2042, 2042],
"Mainly for performance and memory reasons" nargs=2,
"Only for lama") type=int,
parser.add_argument("--crop-margin", type=int, default=256, help="If image size large then crop-trigger-size, "
help="Margin around bounding box of painted stroke when crop mode triggered") "crop each area from original image to do inference."
"Mainly for performance and memory reasons"
"Only for lama",
)
parser.add_argument(
"--crop-margin",
type=int,
default=256,
help="Margin around bounding box of painted stroke when crop mode triggered",
)
parser.add_argument( parser.add_argument(
"--ldm-steps", "--ldm-steps",
default=50, default=50,
@ -131,10 +149,14 @@ def get_args_parser():
"The larger the value, the better the result, but it will be more time-consuming", "The larger the value, the better the result, but it will be more time-consuming",
) )
parser.add_argument("--device", default="cuda", type=str) parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--gui", action="store_true", parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
help="Launch as desktop app") parser.add_argument(
parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int, "--gui-size",
help="Set window size for GUI") default=[1600, 1000],
nargs=2,
type=int,
help="Set window size for GUI",
)
parser.add_argument("--debug", action="store_true") parser.add_argument("--debug", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@ -157,8 +179,11 @@ def main():
input_image_path = args.input input_image_path = args.input
if args.model == "lama": if args.model == "lama":
model = LaMa(crop_trigger_size=args.crop_trigger_size, model = LaMa(
crop_margin=args.crop_margin, device=device) crop_trigger_size=args.crop_trigger_size,
crop_margin=args.crop_margin,
device=device,
)
elif args.model == "ldm": elif args.model == "ldm":
model = LDM(device, steps=args.ldm_steps) model = LDM(device, steps=args.ldm_steps)
else: else: