return correct file ext/mimetype

This commit is contained in:
Sanster 2022-04-09 07:23:33 +08:00
parent 98fa52ba08
commit caed45b520
2 changed files with 59 additions and 32 deletions

View File

@ -30,8 +30,8 @@ def ceil_modulo(x, mod):
return (x // mod + 1) * mod return (x // mod + 1) * mod
def numpy_to_bytes(image_numpy: np.ndarray) -> bytes: def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
data = cv2.imencode(".jpg", image_numpy)[1] data = cv2.imencode(f".{ext}", image_numpy)[1]
image_bytes = data.tobytes() image_bytes = data.tobytes()
return image_bytes return image_bytes
@ -92,7 +92,9 @@ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
""" """
height, width = mask.shape[1:] height, width = mask.shape[1:]
_, thresh = cv2.threshold((mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0) _, thresh = cv2.threshold(
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = [] boxes = []

77
main.py
View File

@ -44,8 +44,7 @@ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"): if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
"./lama_cleaner/app/build")
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False app.config["JSON_AS_ASCII"] = False
@ -56,11 +55,20 @@ device = None
input_image_path: str = None input_image_path: str = None
def _sniff_image_ext(header):
    """Return a format name guessed from leading magic bytes, or None.

    Used only when the stdlib ``imghdr`` module is unavailable. It recognizes
    png, gif, tiff, bmp and webp; anything else yields None so the caller
    can apply its default.
    """
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if header[:6] in (b"GIF87a", b"GIF89a"):
        return "gif"
    if header[:2] in (b"MM", b"II"):
        return "tiff"
    if header.startswith(b"BM"):
        return "bmp"
    if header.startswith(b"RIFF") and header[8:12] == b"WEBP":
        return "webp"
    return None


def get_image_ext(img_bytes: bytes) -> str:
    """Guess the image format of ``img_bytes`` and return its name.

    Args:
        img_bytes: Raw encoded image data.

    Returns:
        The detected format name (for example ``"png"``), or ``"jpeg"`` when
        the format cannot be determined.
    """
    try:
        # imghdr is deprecated since Python 3.11 and removed in 3.13 (PEP 594),
        # so import it lazily and tolerate its absence.
        # NOTE(review): a module-level `import imghdr` elsewhere in the file
        # (not visible here) would need the same guard - confirm.
        import imghdr
    except ImportError:
        detected = _sniff_image_ext(img_bytes)
    else:
        # An empty filename is fine: imghdr inspects the supplied bytes and
        # only opens the file when no bytes are given.
        detected = imghdr.what("", img_bytes)

    # Unrecognized data falls back to JPEG.
    return "jpeg" if detected is None else detected
@app.route("/inpaint", methods=["POST"]) @app.route("/inpaint", methods=["POST"])
def process(): def process():
input = request.files input = request.files
# RGB # RGB
image = load_img(input["image"].read()) origin_image_bytes = input["image"].read()
image = load_img(origin_image_bytes)
original_shape = image.shape original_shape = image.shape
interpolation = cv2.INTER_CUBIC interpolation = cv2.INTER_CUBIC
@ -71,14 +79,12 @@ def process():
size_limit = int(size_limit) size_limit = int(size_limit)
print(f"Origin image shape: {original_shape}") print(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
interpolation=interpolation)
print(f"Resized image shape: {image.shape}") print(f"Resized image shape: {image.shape}")
image = norm_img(image) image = norm_img(image)
mask = load_img(input["mask"].read(), gray=True) mask = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
interpolation=interpolation)
mask = norm_img(mask) mask = norm_img(mask)
start = time.time() start = time.time()
@ -87,11 +93,10 @@ def process():
torch.cuda.empty_cache() torch.cuda.empty_cache()
ext = get_image_ext(origin_image_bytes)
return send_file( return send_file(
io.BytesIO(numpy_to_bytes(res_np_img)), io.BytesIO(numpy_to_bytes(res_np_img, ext)),
mimetype="image/jpeg", mimetype=f"image/{ext}",
as_attachment=True,
attachment_filename="result.jpeg",
) )
@ -100,29 +105,42 @@ def index():
return send_file(os.path.join(BUILD_DIR, "index.html")) return send_file(os.path.join(BUILD_DIR, "index.html"))
@app.route('/inputimage') @app.route("/inputimage")
def set_input_photo(): def set_input_photo():
if input_image_path: if input_image_path:
with open(input_image_path, 'rb') as f: with open(input_image_path, "rb") as f:
image_in_bytes = f.read() image_in_bytes = f.read()
return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg') return send_file(
io.BytesIO(image_in_bytes),
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else: else:
return 'No Input Image' return "No Input Image"
def get_args_parser(): def get_args_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--input", type=str, help="Path to image you want to load by default") "--input", type=str, help="Path to image you want to load by default"
)
parser.add_argument("--port", default=8080, type=int) parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"]) parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument("--crop-trigger-size", default=[2042, 2042], nargs=2, type=int, parser.add_argument(
"--crop-trigger-size",
default=[2042, 2042],
nargs=2,
type=int,
help="If image size large then crop-trigger-size, " help="If image size large then crop-trigger-size, "
"crop each area from original image to do inference." "crop each area from original image to do inference."
"Mainly for performance and memory reasons" "Mainly for performance and memory reasons"
"Only for lama") "Only for lama",
parser.add_argument("--crop-margin", type=int, default=256, )
help="Margin around bounding box of painted stroke when crop mode triggered") parser.add_argument(
"--crop-margin",
type=int,
default=256,
help="Margin around bounding box of painted stroke when crop mode triggered",
)
parser.add_argument( parser.add_argument(
"--ldm-steps", "--ldm-steps",
default=50, default=50,
@ -131,10 +149,14 @@ def get_args_parser():
"The larger the value, the better the result, but it will be more time-consuming", "The larger the value, the better the result, but it will be more time-consuming",
) )
parser.add_argument("--device", default="cuda", type=str) parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--gui", action="store_true", parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
help="Launch as desktop app") parser.add_argument(
parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int, "--gui-size",
help="Set window size for GUI") default=[1600, 1000],
nargs=2,
type=int,
help="Set window size for GUI",
)
parser.add_argument("--debug", action="store_true") parser.add_argument("--debug", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@ -157,8 +179,11 @@ def main():
input_image_path = args.input input_image_path = args.input
if args.model == "lama": if args.model == "lama":
model = LaMa(crop_trigger_size=args.crop_trigger_size, model = LaMa(
crop_margin=args.crop_margin, device=device) crop_trigger_size=args.crop_trigger_size,
crop_margin=args.crop_margin,
device=device,
)
elif args.model == "ldm": elif args.model == "ldm":
model = LDM(device, steps=args.ldm_steps) model = LDM(device, steps=args.ldm_steps)
else: else: