check --input before start server
This commit is contained in:
parent
0cc17ea322
commit
705e12d02d
23
main.py
23
main.py
@ -53,6 +53,7 @@ CORS(app)
|
|||||||
|
|
||||||
model = None
|
model = None
|
||||||
device = None
|
device = None
|
||||||
|
input_image_path: str = None
|
||||||
|
|
||||||
|
|
||||||
@app.route("/inpaint", methods=["POST"])
|
@app.route("/inpaint", methods=["POST"])
|
||||||
@ -101,11 +102,8 @@ def index():
|
|||||||
|
|
||||||
@app.route('/inputimage')
|
@app.route('/inputimage')
|
||||||
def set_input_photo():
|
def set_input_photo():
|
||||||
if input_image:
|
if input_image_path:
|
||||||
input_file = os.path.join(os.path.dirname(__file__), input_image)
|
with open(input_image_path, 'rb') as f:
|
||||||
if os.path.exists(input_file): # Check if file exists
|
|
||||||
if imghdr.what(input_file) is not None: # Check if file is image
|
|
||||||
with open(input_file, 'rb') as f:
|
|
||||||
image_in_bytes = f.read()
|
image_in_bytes = f.read()
|
||||||
return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg')
|
return send_file(io.BytesIO(image_in_bytes), mimetype='image/jpeg')
|
||||||
else:
|
else:
|
||||||
@ -138,18 +136,25 @@ def get_args_parser():
|
|||||||
parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int,
|
parser.add_argument("--gui-size", default=[1600, 1000], nargs=2, type=int,
|
||||||
help="Set window size for GUI")
|
help="Set window size for GUI")
|
||||||
parser.add_argument("--debug", action="store_true")
|
parser.add_argument("--debug", action="store_true")
|
||||||
return parser.parse_args()
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.input is not None:
|
||||||
|
if not os.path.exists(args.input):
|
||||||
|
parser.error(f"invalid --input: {args.input} not exists")
|
||||||
|
if imghdr.what(args.input) is None:
|
||||||
|
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global model
|
global model
|
||||||
global device
|
global device
|
||||||
global input_image
|
global input_image_path
|
||||||
|
|
||||||
args = get_args_parser()
|
args = get_args_parser()
|
||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
|
input_image_path = args.input
|
||||||
input_image = args.input
|
|
||||||
|
|
||||||
if args.model == "lama":
|
if args.model == "lama":
|
||||||
model = LaMa(crop_trigger_size=args.crop_trigger_size,
|
model = LaMa(crop_trigger_size=args.crop_trigger_size,
|
||||||
|
Loading…
Reference in New Issue
Block a user