check --input before start server

This commit is contained in:
Sanster 2022-03-27 13:37:26 +08:00
parent 0cc17ea322
commit 705e12d02d

27
main.py
View File

@ -53,6 +53,7 @@ CORS(app)
model = None model = None
device = None device = None
input_image_path: str = None
@app.route("/inpaint", methods=["POST"]) @app.route("/inpaint", methods=["POST"])
@ -101,13 +102,10 @@ def index():
@app.route('/inputimage') @app.route('/inputimage')
def set_input_photo(): def set_input_photo():
if input_image: if input_image_path:
input_file = os.path.join(os.path.dirname(__file__), input_image) with open(input_image_path, 'rb') as f:
if os.path.exists(input_file): # Check if file exists image_in_bytes = f.read()
if imghdr.what(input_file) is not None: # Check if file is image return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg')
with open(input_file, 'rb') as f:
image_in_bytes = f.read()
return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg')
else: else:
return 'No Input Image' return 'No Input Image'
@ -138,18 +136,25 @@ def get_args_parser():
parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int, parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int,
help="Set window size for GUI") help="Set window size for GUI")
parser.add_argument("--debug", action="store_true") parser.add_argument("--debug", action="store_true")
return parser.parse_args()
args = parser.parse_args()
if args.input is not None:
if not os.path.exists(args.input):
parser.error(f"invalid --input: {args.input} not exists")
if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file")
return args
def main(): def main():
global model global model
global device global device
global input_image global input_image_path
args = get_args_parser() args = get_args_parser()
device = torch.device(args.device) device = torch.device(args.device)
input_image_path = args.input
input_image = args.input
if args.model == "lama": if args.model == "lama":
model = LaMa(crop_trigger_size=args.crop_trigger_size, model = LaMa(crop_trigger_size=args.crop_trigger_size,