From 705e12d02dc1a12c5552bb2f900c74bed0aea7f1 Mon Sep 17 00:00:00 2001 From: Sanster Date: Sun, 27 Mar 2022 13:37:26 +0800 Subject: [PATCH] check --input before start server --- main.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 0716a32..08b4e42 100644 --- a/main.py +++ b/main.py @@ -53,6 +53,7 @@ CORS(app) model = None device = None +input_image_path: str = None @app.route("/inpaint", methods=["POST"]) @@ -101,13 +102,10 @@ def index(): @app.route('/inputimage') def set_input_photo(): - if input_image: - input_file = os.path.join(os.path.dirname(__file__), input_image) - if os.path.exists(input_file): # Check if file exists - if imghdr.what(input_file) is not None: # Check if file is image - with open(input_file, 'rb') as f: - image_in_bytes = f.read() - return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg') + if input_image_path: + with open(input_image_path, 'rb') as f: + image_in_bytes = f.read() + return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg') else: return 'No Input Image' @@ -138,18 +136,25 @@ def get_args_parser(): parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int, help="Set window size for GUI") parser.add_argument("--debug", action="store_true") - return parser.parse_args() + + args = parser.parse_args() + if args.input is not None: + if not os.path.exists(args.input): + parser.error(f"invalid --input: {args.input} not exists") + if imghdr.what(args.input) is None: + parser.error(f"invalid --input: {args.input} is not a valid image file") + + return args def main(): global model global device - global input_image + global input_image_path args = get_args_parser() device = torch.device(args.device) - - input_image = args.input + input_image_path = args.input if args.model == "lama": model = LaMa(crop_trigger_size=args.crop_trigger_size,