diff --git a/iopaint/api.py b/iopaint/api.py index 8fa7bfd..cace639 100644 --- a/iopaint/api.py +++ b/iopaint/api.py @@ -167,6 +167,7 @@ class Api: self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"]) self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"]) + self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"]) self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") # fmt: on @@ -179,6 +180,12 @@ class Api: def add_api_route(self, path: str, endpoint, **kwargs): return self.app.add_api_route(path, endpoint, **kwargs) + def api_save_image(self, file: UploadFile): + filename = file.filename + origin_image_bytes = file.file.read() + with open(self.config.output_dir / filename, "wb") as fw: + fw.write(origin_image_bytes) + def api_current_model(self) -> ModelInfo: return self.model_manager.current_model diff --git a/iopaint/cli.py b/iopaint/cli.py index 88fda7b..951fbb4 100644 --- a/iopaint/cli.py +++ b/iopaint/cli.py @@ -144,7 +144,10 @@ def start( device = check_device(device) if input and not input.exists(): logger.error(f"invalid --input: {input} not exists") - exit() + exit(-1) + if input and input.is_dir() and not output_dir: + logger.error(f"invalid --output-dir: must be set when --input is a directory") + exit(-1) if output_dir: output_dir = output_dir.expanduser().absolute() logger.info(f"Image will be saved to {output_dir}") diff --git a/iopaint/file_manager/file_manager.py b/iopaint/file_manager/file_manager.py index cb33278..413162c 100644 --- a/iopaint/file_manager/file_manager.py +++ b/iopaint/file_manager/file_manager.py @@ -27,18 +27,11 @@ class FileManager: self.thumbnail_directory.mkdir(parents=True) # fmt: off - self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"]) self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse]) self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"]) self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"]) # fmt: on - def api_save_image(self, file: UploadFile): - filename = file.filename - origin_image_bytes = file.file.read() - with open(self.output_dir / filename, "wb") as fw: - fw.write(origin_image_bytes) - def api_medias(self, tab: MediaTab) -> List[MediasResponse]: img_dir = self._get_dir(tab) return self._media_names(img_dir) diff --git a/web_app/src/components/Settings.tsx b/web_app/src/components/Settings.tsx index 5138e03..18a021c 100644 --- a/web_app/src/components/Settings.tsx +++ b/web_app/src/components/Settings.tsx @@ -644,7 +644,6 @@ export function SettingsDialog() {