diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts
index 17c9b6b..d2d1028 100644
--- a/lama_cleaner/app/src/adapters/inpainting.ts
+++ b/lama_cleaner/app/src/adapters/inpainting.ts
@@ -1,10 +1,12 @@
+import { Setting } from '../store/Atoms'
 import { dataURItoBlob } from '../utils'
 
-export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
+export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
 
 export default async function inpaint(
   imageFile: File,
   maskBase64: string,
+  settings: Setting,
   sizeLimit?: string
 ) {
   // 1080, 2000, Original
@@ -13,13 +15,22 @@ export default async function inpaint(
   const mask = dataURItoBlob(maskBase64)
   fd.append('mask', mask)
 
+  fd.append('ldmSteps', settings.ldmSteps.toString())
+  fd.append('hdStrategy', settings.hdStrategy)
+  fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
+  fd.append(
+    'hdStrategyCropTrigerSize',
+    settings.hdStrategyCropTrigerSize.toString()
+  )
+  fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())
+
   if (sizeLimit === undefined) {
     fd.append('sizeLimit', '1080')
   } else {
     fd.append('sizeLimit', sizeLimit)
   }
 
-  const res = await fetch(API_ENDPOINT, {
+  const res = await fetch(`${API_ENDPOINT}/inpaint`, {
     method: 'POST',
     body: fd,
   }).then(async r => {
@@ -28,3 +39,12 @@ export default async function inpaint(
 
   return URL.createObjectURL(res)
 }
+
+export function switchModel(name: string) {
+  const fd = new FormData()
+  fd.append('name', name)
+  return fetch(`${API_ENDPOINT}/switch_model`, {
+    method: 'POST',
+    body: fd,
+  })
+}
diff --git a/lama_cleaner/app/src/components/Editor/Editor.tsx b/lama_cleaner/app/src/components/Editor/Editor.tsx
index d987307..b5e3ca9 100644
--- a/lama_cleaner/app/src/components/Editor/Editor.tsx
+++ b/lama_cleaner/app/src/components/Editor/Editor.tsx
@@ -15,12 +15,14 @@ import {
   TransformComponent,
   TransformWrapper,
 } from 'react-zoom-pan-pinch'
+import { useRecoilValue } from 'recoil'
 import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
 import inpaint from '../../adapters/inpainting'
 import Button from '../shared/Button'
 import Slider from './Slider'
 import SizeSelector from './SizeSelector'
 import { downloadImage, loadImage, useImage } from '../../utils'
+import { settingState } from '../../store/Atoms'
 
 const TOOLBAR_SIZE = 200
 const BRUSH_COLOR = '#ffcc00bb'
@@ -57,6 +59,7 @@ function drawLines(
 
 export default function Editor(props: EditorProps) {
   const { file } = props
+  const settings = useRecoilValue(settingState)
   const [brushSize, setBrushSize] = useState(40)
   const [original, isOriginalLoaded] = useImage(file)
   const [renders, setRenders] = useState([])
@@ -125,6 +128,7 @@ export default function Editor(props: EditorProps) {
       const res = await inpaint(
         file,
         maskCanvas.toDataURL(),
+        settings,
         sizeLimit.toString()
       )
       if (!res) {
@@ -157,6 +161,7 @@ export default function Editor(props: EditorProps) {
     renders,
     sizeLimit,
     historyLineCount,
+    settings,
   ])
 
   const hadDrawSomething = () => {
diff --git a/lama_cleaner/app/src/components/Header/Header.tsx b/lama_cleaner/app/src/components/Header/Header.tsx
index c0fb1cc..8a84e02 100644
--- a/lama_cleaner/app/src/components/Header/Header.tsx
+++ b/lama_cleaner/app/src/components/Header/Header.tsx
@@ -6,7 +6,7 @@ import Button from '../shared/Button'
 import Shortcuts from '../Shortcuts/Shortcuts'
 import useResolution from '../../hooks/useResolution'
 import { ThemeChanger } from './ThemeChanger'
-import SettingIcon from '../Setting/SettingIcon'
+import SettingIcon from '../Settings/SettingIcon'
 
 const Header = () => {
   const [file, setFile] = useRecoilState(fileState)
diff --git a/lama_cleaner/app/src/components/Setting/HDSettingBlock.scss b/lama_cleaner/app/src/components/Settings/HDSettingBlock.scss
similarity index 100%
rename from lama_cleaner/app/src/components/Setting/HDSettingBlock.scss
rename to lama_cleaner/app/src/components/Settings/HDSettingBlock.scss
diff --git a/lama_cleaner/app/src/components/Setting/HDSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx
similarity index 77%
rename from lama_cleaner/app/src/components/Setting/HDSettingBlock.tsx
rename to lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx
index 2d87f19..c4ce001 100644
--- a/lama_cleaner/app/src/components/Setting/HDSettingBlock.tsx
+++ b/lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx
@@ -1,50 +1,16 @@
 import React, { ReactNode } from 'react'
 import { useRecoilState } from 'recoil'
 import { settingState } from '../../store/Atoms'
-import NumberInput from '../shared/NumberInput'
 import Selector from '../shared/Selector'
+import NumberInputSetting from './NumberInputSetting'
 import SettingBlock from './SettingBlock'
 
 export enum HDStrategy {
   ORIGINAL = 'Original',
-  REISIZE = 'Resize',
+  RESIZE = 'Resize',
   CROP = 'Crop',
 }
 
-interface PixelSizeInputProps {
-  title: string
-  value: string
-  onValue: (val: string) => void
-}
-
-function PixelSizeInputSetting(props: PixelSizeInputProps) {
-  const { title, value, onValue } = props
-
-  return (
-    <SettingBlock
-      className="sub-setting-block"
-      title={title}
-      input={
-        <>
-          <NumberInput value={`${value}`} onValue={onValue} />
-          <span>pixel</span>
-        </>
-      }
-    />
-  )
-}
-
 function HDSettingBlock() {
   const [setting, setSettingState] = useRecoilState(settingState)
@@ -84,7 +50,7 @@ function HDSettingBlock() {
         tabIndex={0}
         role="button"
         className="inline-tip"
-        onClick={() => onStrategyChange(HDStrategy.REISIZE)}
+        onClick={() => onStrategyChange(HDStrategy.RESIZE)}
       >
         Resize Strategy
       </div>{' '}
@@ -100,9 +66,10 @@ function HDSettingBlock() {
       <div>
         Resize the longer side of the image to a specific size(keep ratio),
         then do inpainting on the resized image.
       </div>
-      <PixelSizeInputSetting
-        title="Size limit"
-        value={`${setting.hdStrategyResizeLimit}`}
-        onValue={onResizeLimitChange}
-      />
+      <NumberInputSetting
+        title="Size limit"
+        value={`${setting.hdStrategyResizeLimit}`}
+        suffix="pixel"
+        onValue={onResizeLimitChange}
+      />
@@ -117,14 +84,16 @@ function HDSettingBlock() {
         the result back. Mainly for performance and memory reasons on high
         resolution image.
       </div>
-      <PixelSizeInputSetting
-        title="Trigger size"
-        value={`${setting.hdStrategyCropTrigerSize}`}
-        onValue={onCropTriggerSizeChange}
-      />
-      <PixelSizeInputSetting
-        title="Crop margin"
-        value={`${setting.hdStrategyCropMargin}`}
-        onValue={onCropMarginChange}
-      />
+      <NumberInputSetting
+        title="Trigger size"
+        value={`${setting.hdStrategyCropTrigerSize}`}
+        suffix="pixel"
+        onValue={onCropTriggerSizeChange}
+      />
+      <NumberInputSetting
+        title="Crop margin"
+        value={`${setting.hdStrategyCropMargin}`}
+        suffix="pixel"
+        onValue={onCropMarginChange}
+      />
     </>
@@ -137,7 +106,7 @@ function HDSettingBlock() {
         return renderOriginalOptionDesc()
       case HDStrategy.CROP:
         return renderCropOptionDesc()
-      case HDStrategy.REISIZE:
+      case HDStrategy.RESIZE:
        return renderResizeOptionDesc()
       default:
         return renderOriginalOptionDesc()
diff --git a/lama_cleaner/app/src/components/Setting/ModelSettingBlock.scss b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.scss
similarity index 100%
rename from lama_cleaner/app/src/components/Setting/ModelSettingBlock.scss
rename to lama_cleaner/app/src/components/Settings/ModelSettingBlock.scss
diff --git a/lama_cleaner/app/src/components/Setting/ModelSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
similarity index 74%
rename from lama_cleaner/app/src/components/Setting/ModelSettingBlock.tsx
rename to lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
index 2d7891c..dce8851 100644
--- a/lama_cleaner/app/src/components/Setting/ModelSettingBlock.tsx
+++ b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
@@ -2,11 +2,12 @@ import React, { ReactNode } from 'react'
 import { useRecoilState } from 'recoil'
 import { settingState } from '../../store/Atoms'
 import Selector from '../shared/Selector'
+import NumberInputSetting from './NumberInputSetting'
 import SettingBlock from './SettingBlock'
 
 export enum AIModel {
-  LAMA = 'LaMa',
-  LDM = 'LDM',
+  LAMA = 'lama',
+  LDM = 'ldm',
 }
 
 function ModelSettingBlock() {
@@ -24,7 +25,7 @@ function ModelSettingBlock() {
     githubUrl: string
   ) => {
     return (
-
+
-
-
           {githubUrl}
@@ -49,6 +47,28 @@ function ModelSettingBlock() {
     )
   }
 
+  const renderLDMModelDesc = () => {
+    return (
+      <div>
+        {renderModelDesc(
+          'High-Resolution Image Synthesis with Latent Diffusion Models',
+          'https://arxiv.org/abs/2112.10752',
+          'https://github.com/CompVis/latent-diffusion'
+        )}
+        <NumberInputSetting
+          title="Steps"
+          value={`${setting.ldmSteps}`}
+          onValue={value => {
+            const val = value.length === 0 ? 0 : parseInt(value, 10)
+            setSettingState(old => {
+              return { ...old, ldmSteps: val }
+            })
+          }}
+        />
+      </div>
+    )
+  }
+
   const renderOptionDesc = (): ReactNode => {
     switch (setting.model) {
       case AIModel.LAMA:
@@ -58,11 +78,7 @@ function ModelSettingBlock() {
           'https://github.com/saic-mdal/lama'
         )
       case AIModel.LDM:
-        return renderModelDesc(
-          'High-Resolution Image Synthesis with Latent Diffusion Models',
-          'https://arxiv.org/abs/2112.10752',
-          'https://github.com/CompVis/latent-diffusion'
-        )
+        return renderLDMModelDesc()
       default:
         return <></>
     }
diff --git a/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx b/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx
new file mode 100644
index 0000000..543a92a
--- /dev/null
+++ b/lama_cleaner/app/src/components/Settings/NumberInputSetting.tsx
@@ -0,0 +1,40 @@
+import React from 'react'
+import NumberInput from '../shared/NumberInput'
+import SettingBlock from './SettingBlock'
+
+interface NumberInputSettingProps {
+  title: string
+  value: string
+  suffix?: string
+  onValue: (val: string) => void
+}
+
+function NumberInputSetting(props: NumberInputSettingProps) {
+  const { title, value, suffix, onValue } = props
+
+  return (
+    <SettingBlock
+      className="sub-setting-block"
+      title={title}
+      input={
+        <>
+          <NumberInput value={`${value}`} onValue={onValue} />
+          {suffix && <span>{suffix}</span>}
+        </>
+      }
+    />
+  )
+}
+
+export default NumberInputSetting
diff --git a/lama_cleaner/app/src/components/Setting/SavePathSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/SavePathSettingBlock.tsx
similarity index 51%
rename from lama_cleaner/app/src/components/Setting/SavePathSettingBlock.tsx
rename to lama_cleaner/app/src/components/Settings/SavePathSettingBlock.tsx
index 536e67a..06bc3dd 100644
--- a/lama_cleaner/app/src/components/Setting/SavePathSettingBlock.tsx
+++ b/lama_cleaner/app/src/components/Settings/SavePathSettingBlock.tsx
@@ -1,14 +1,26 @@
 import React, { ReactNode } from 'react'
 import { useRecoilState } from 'recoil'
+import { settingState } from '../../store/Atoms'
 import { Switch, SwitchThumb } from '../shared/Switch'
 import SettingBlock from './SettingBlock'
 
 function SavePathSettingBlock() {
+  const [setting, setSettingState] = useRecoilState(settingState)
+
+  const onCheckChange = (checked: boolean) => {
+    setSettingState(old => {
+      return { ...old, saveImageBesideOrigin: checked }
+    })
+  }
+
   return (
     <SettingBlock
       title="Save image beside origin image"
       input={
-        <Switch>
+        <Switch
+          checked={setting.saveImageBesideOrigin}
+          onCheckedChange={onCheckChange}
+        >
           <SwitchThumb />
         </Switch>
       }
diff --git a/lama_cleaner/app/src/components/Setting/SettingBlock.scss b/lama_cleaner/app/src/components/Settings/SettingBlock.scss
similarity index 97%
rename from lama_cleaner/app/src/components/Setting/SettingBlock.scss
rename to lama_cleaner/app/src/components/Settings/SettingBlock.scss
index 91a2ee3..eabc7ae 100644
--- a/lama_cleaner/app/src/components/Setting/SettingBlock.scss
+++ b/lama_cleaner/app/src/components/Settings/SettingBlock.scss
@@ -7,7 +7,7 @@
   margin-top: 12px;
   border: 1px solid var(--border-color);
   border-radius: 0.3rem;
-  padding: 2rem;
+  padding: 1rem;
 
   .sub-setting-block {
     margin-top: 8px;
diff --git a/lama_cleaner/app/src/components/Setting/SettingBlock.tsx b/lama_cleaner/app/src/components/Settings/SettingBlock.tsx
similarity index 100%
rename from lama_cleaner/app/src/components/Setting/SettingBlock.tsx
rename to lama_cleaner/app/src/components/Settings/SettingBlock.tsx
diff --git a/lama_cleaner/app/src/components/Setting/SettingIcon.tsx b/lama_cleaner/app/src/components/Settings/SettingIcon.tsx
similarity index 100%
rename from lama_cleaner/app/src/components/Setting/SettingIcon.tsx
rename to lama_cleaner/app/src/components/Settings/SettingIcon.tsx
diff --git a/lama_cleaner/app/src/components/Setting/Setting.scss b/lama_cleaner/app/src/components/Settings/Settings.scss
similarity index 95%
rename from lama_cleaner/app/src/components/Setting/Setting.scss
rename to lama_cleaner/app/src/components/Settings/Settings.scss
index 17adcb6..e743959 100644
--- a/lama_cleaner/app/src/components/Setting/Setting.scss
+++ b/lama_cleaner/app/src/components/Settings/Settings.scss
@@ -8,7 +8,6 @@
   background-color: var(--modal-bg);
   color: var(--modal-text-color);
   box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
-  min-height: 600px;
   width: 700px;
 
   @include mobile {
diff --git a/lama_cleaner/app/src/components/Setting/SettingModal.tsx b/lama_cleaner/app/src/components/Settings/SettingsModal.tsx
similarity index 68%
rename from lama_cleaner/app/src/components/Setting/SettingModal.tsx
rename to lama_cleaner/app/src/components/Settings/SettingsModal.tsx
index 262b9e1..c264e80 100644
--- a/lama_cleaner/app/src/components/Setting/SettingModal.tsx
+++ b/lama_cleaner/app/src/components/Settings/SettingsModal.tsx
@@ -1,11 +1,11 @@
 import React from 'react'
 import { useRecoilState } from 'recoil'
 
+import { switchModel } from '../../adapters/inpainting'
 import { settingState } from '../../store/Atoms'
 import Modal from '../shared/Modal'
 import HDSettingBlock from './HDSettingBlock'
 import ModelSettingBlock from './ModelSettingBlock'
-import SavePathSettingBlock from './SavePathSettingBlock'
 
 export default function SettingModal() {
   const [setting, setSettingState] = useRecoilState(settingState)
@@ -14,6 +14,8 @@ export default function SettingModal() {
     setSettingState(old => {
       return { ...old, show: false }
     })
+
+    switchModel(setting.model)
   }
 
   return (
@@ -23,7 +25,9 @@ export default function SettingModal() {
       className="modal-setting"
       show={setting.show}
     >
-      <SavePathSettingBlock />
+      {/* It's not possible because this poses a security risk */}
+      {/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
+      {/* <SavePathSettingBlock /> */}
       <ModelSettingBlock />
       <HDSettingBlock />
     </Modal>
diff --git a/lama_cleaner/app/src/components/Workspace.tsx b/lama_cleaner/app/src/components/Workspace.tsx
index f084296..86b6852 100644
--- a/lama_cleaner/app/src/components/Workspace.tsx
+++ b/lama_cleaner/app/src/components/Workspace.tsx
@@ -1,9 +1,7 @@
 import React from 'react'
-import { useRecoilValue } from 'recoil'
 import Editor from './Editor/Editor'
-import { shortcutsState } from '../store/Atoms'
 import ShortcutsModal from './Shortcuts/ShortcutsModal'
-import SettingModal from './Setting/SettingModal'
+import SettingModal from './Settings/SettingsModal'
 
 interface WorkspaceProps {
   file: File
diff --git a/lama_cleaner/app/src/components/shared/Modal.tsx b/lama_cleaner/app/src/components/shared/Modal.tsx
index e7db92e..f96de98 100644
--- a/lama_cleaner/app/src/components/shared/Modal.tsx
+++ b/lama_cleaner/app/src/components/shared/Modal.tsx
@@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
   const ref = useRef(null)
 
   useClickAway(ref, () => {
-    onClose?.()
+    if (show) {
+      onClose?.()
+    }
   })
 
   useKeyPressEvent('Escape', e => {
-    onClose?.()
+    if (show) {
+      onClose?.()
+    }
   })
 
   return (
diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx
index 82e2362..bad9e43 100644
--- a/lama_cleaner/app/src/store/Atoms.tsx
+++ b/lama_cleaner/app/src/store/Atoms.tsx
@@ -1,6 +1,6 @@
 import { atom } from 'recoil'
-import { HDStrategy } from '../components/Setting/HDSettingBlock'
-import { AIModel } from '../components/Setting/ModelSettingBlock'
+import { HDStrategy } from '../components/Settings/HDSettingBlock'
+import { AIModel } from '../components/Settings/ModelSettingBlock'
 
 export const fileState = atom<File | undefined>({
   key: 'fileState',
@@ -16,10 +16,15 @@ export interface Setting {
   show: boolean
   saveImageBesideOrigin: boolean
   model: AIModel
+
+  // For LaMa
   hdStrategy: HDStrategy
   hdStrategyResizeLimit: number
   hdStrategyCropTrigerSize: number
   hdStrategyCropMargin: number
+
+  // For LDM
+  ldmSteps: number
 }
 
 export const settingState = atom<Setting>({
@@ -28,9 +33,10 @@
     show: false,
     saveImageBesideOrigin: false,
     model: AIModel.LAMA,
-    hdStrategy: HDStrategy.ORIGINAL,
+    hdStrategy: HDStrategy.RESIZE,
     hdStrategyResizeLimit: 2048,
     hdStrategyCropTrigerSize: 2048,
     hdStrategyCropMargin: 128,
+    ldmSteps: 50,
   },
 })
diff --git a/lama_cleaner/app/src/styles/_index.scss b/lama_cleaner/app/src/styles/_index.scss
index e8362a6..d6e4210 100644
--- a/lama_cleaner/app/src/styles/_index.scss
+++ b/lama_cleaner/app/src/styles/_index.scss
@@ -11,7 +11,7 @@
 @use '../components/Header/Header';
 @use '../components/Header/ThemeChanger';
 @use '../components/Shortcuts/Shortcuts';
-@use '../components/Setting/Setting.scss';
+@use '../components/Settings/Settings.scss';
 
 // Shared
 @use '../components/FileSelect/FileSelect';
diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py
index 101d894..2b7768b 100644
--- a/lama_cleaner/helper.py
+++ b/lama_cleaner/helper.py
@@ -31,7 +31,11 @@ def ceil_modulo(x, mod):
 
 
 def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
-    data = cv2.imencode(f".{ext}", image_numpy)[1]
+    data = cv2.imencode(f".{ext}", image_numpy,
+                        [
+                            int(cv2.IMWRITE_JPEG_QUALITY), 100,
+                            int(cv2.IMWRITE_PNG_COMPRESSION), 0
+                        ])[1]
     image_bytes = data.tobytes()
     return image_bytes
 
@@ -74,13 +78,24 @@ def resize_max_size(
     return np_img
 
 
-def pad_img_to_modulo(img, mod):
-    channels, height, width = img.shape
+def pad_img_to_modulo(img: np.ndarray, mod: int):
+    """
+
+    Args:
+        img: [H, W, C]
+        mod:
+
+    Returns:
+
+    """
+    if len(img.shape) == 2:
+        img = img[:, :, np.newaxis]
+    height, width = img.shape[:2]
     out_height = ceil_modulo(height, mod)
     out_width = ceil_modulo(width, mod)
     return np.pad(
         img,
-        ((0, 0), (0, out_height - height), (0, out_width - width)),
+        ((0, out_height - height), (0, out_width - width), (0, 0)),
         mode="symmetric",
     )
 
@@ -88,15 +103,13 @@ def pad_img_to_modulo(img, mod):
 def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
     """
     Args:
-        mask: (1, h, w) 0~1
+        mask: (h, w, 1) 0~255
 
     Returns:
 
     """
-    height, width = mask.shape[1:]
-    _, thresh = cv2.threshold(
-        (mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
-    )
+    height, width = mask.shape[:2]
+    _, thresh = cv2.threshold(mask, 127, 255, 0)
 
     contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
     boxes = []
diff --git a/lama_cleaner/lama/__init__.py b/lama_cleaner/lama/__init__.py
deleted file mode 100644
index a30a1f7..0000000
--- a/lama_cleaner/lama/__init__.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import os
-from typing import List
-
-import cv2
-import torch
-import numpy as np
-
-from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
-
-LAMA_MODEL_URL = os.environ.get(
-    "LAMA_MODEL_URL",
-    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-)
-
-
-class LaMa:
-    def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
-        """
-
-        Args:
-            crop_trigger_size: h, w
-            crop_margin:
-            device:
-        """
-        self.crop_trigger_size = crop_trigger_size
-        self.crop_margin = crop_margin
-        self.device = device
-
-        if os.environ.get("LAMA_MODEL"):
-            model_path = os.environ.get("LAMA_MODEL")
-            if not os.path.exists(model_path):
-                raise FileNotFoundError(
-                    f"lama torchscript model not found: {model_path}"
-                )
-        else:
-            model_path = download_model(LAMA_MODEL_URL)
-
-        print(f"Load LaMa model from: {model_path}")
-        model = torch.jit.load(model_path, map_location="cpu")
-        model = model.to(device)
-        model.eval()
-        self.model = model
-
-    @torch.no_grad()
-    def __call__(self, image, mask):
-        """
-        image: [C, H, W] RGB
-        mask: [1, H, W]
-        return: BGR IMAGE
-        """
-        area = image.shape[1] * image.shape[2]
-        if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
-            return self._run(image, mask)
-
-        print("Trigger crop image")
-        boxes = boxes_from_mask(mask)
-        crop_result = []
-        for box in boxes:
-            crop_image, crop_box = self._run_box(image, mask, box)
-            crop_result.append((crop_image, crop_box))
-
-        image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
-        for crop_image, crop_box in crop_result:
-            x1, y1, x2, y2 = crop_box
-            image[y1:y2, x1:x2, :] = crop_image
-        return image
-
-    def _run_box(self, image, mask, box):
-        """
-
-        Args:
-            image: [C, H, W] RGB
-            mask: [1, H, W]
-            box: [left,top,right,bottom]
-
-        Returns:
-            BGR IMAGE
-        """
-        box_h = box[3] - box[1]
-        box_w = box[2] - box[0]
-        cx = (box[0] + box[2]) // 2
-        cy = (box[1] + box[3]) // 2
-        img_h, img_w = image.shape[1:]
-
-        w = box_w + self.crop_margin * 2
-        h = box_h + self.crop_margin * 2
-
-        l = max(cx - w // 2, 0)
-        t = max(cy - h // 2, 0)
-        r = min(cx + w // 2, img_w)
-        b = min(cy + h // 2, img_h)
-
-        crop_img = image[:, t:b, l:r]
-        crop_mask = mask[:, t:b, l:r]
-
-        print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
-
-        return self._run(crop_img, crop_mask), [l, t, r, b]
-
-    def _run(self, image, mask):
-        """
-        image: [C, H, W] RGB
-        mask: [1, H, W]
-        return: BGR IMAGE
-        """
-        device = self.device
-        origin_height, origin_width = image.shape[1:]
-        image = pad_img_to_modulo(image, mod=8)
-        mask = pad_img_to_modulo(mask, mod=8)
-
-        mask = (mask > 0) * 1
-        image = torch.from_numpy(image).unsqueeze(0).to(device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
-
-        inpainted_image = self.model(image, mask)
-
-        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
-        cur_res = cur_res[0:origin_height, 0:origin_width, :]
-        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
-        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
-        return cur_res
diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py
new file mode 100644
index 0000000..31059dc
--- /dev/null
+++ b/lama_cleaner/model/base.py
@@ -0,0 +1,122 @@
+import abc
+
+import cv2
+import torch
+from loguru import logger
+
+from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
+from lama_cleaner.schema import Config, HDStrategy
+
+
+class InpaintModel:
+    pad_mod = 8
+
+    def __init__(self, device):
+        """
+
+        Args:
+            device:
+        """
+        self.device = device
+        self.init_model(device)
+
+    @abc.abstractmethod
+    def init_model(self, device):
+        ...
+
+    @abc.abstractmethod
+    def forward(self, image, mask, config: Config):
+        """Input image and output image have same size
+        image: [H, W, C] RGB
+        mask: [H, W]
+        return: BGR IMAGE
+        """
+        ...
+
+    def _pad_forward(self, image, mask, config: Config):
+        origin_height, origin_width = image.shape[:2]
+        padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
+        padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
+        result = self.forward(padd_image, padd_mask, config)
+        result = result[0:origin_height, 0:origin_width, :]
+
+        original_pixel_indices = mask != 255
+        result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
+        return result
+
+    @torch.no_grad()
+    def __call__(self, image, mask, config: Config):
+        """
+        image: [H, W, C] RGB, not normalized
+        mask: [H, W]
+        return: BGR IMAGE
+        """
+        inpaint_result = None
+        logger.info(f"hd_strategy: {config.hd_strategy}")
+        if config.hd_strategy == HDStrategy.CROP:
+            if max(image.shape) > config.hd_strategy_crop_trigger_size:
+                logger.info(f"Run crop strategy")
+                boxes = boxes_from_mask(mask)
+                crop_result = []
+                for box in boxes:
+                    crop_image, crop_box = self._run_box(image, mask, box, config)
+                    crop_result.append((crop_image, crop_box))
+
+                inpaint_result = image[:, :, ::-1]
+                for crop_image, crop_box in crop_result:
+                    x1, y1, x2, y2 = crop_box
+                    inpaint_result[y1:y2, x1:x2, :] = crop_image
+
+        elif config.hd_strategy == HDStrategy.RESIZE:
+            if max(image.shape) > config.hd_strategy_resize_limit:
+                origin_size = image.shape[:2]
+                downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
+                downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
+
+                logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
+                inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
+
+                # only paste masked area result
+                inpaint_result = cv2.resize(inpaint_result,
+                                            (origin_size[1], origin_size[0]),
+                                            interpolation=cv2.INTER_CUBIC)
+
+                original_pixel_indices = mask != 255
+                inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
+
+        if inpaint_result is None:
+            inpaint_result = self._pad_forward(image, mask, config)
+
+        return inpaint_result
+
+    def _run_box(self, image, mask, box, config: Config):
+        """
+
+        Args:
+            image: [H, W, C] RGB
+            mask: [H, W, 1]
+            box: [left,top,right,bottom]
+
+        Returns:
+            BGR IMAGE
+        """
+        box_h = box[3] - box[1]
+        box_w = box[2] - box[0]
+        cx = (box[0] + box[2]) // 2
+        cy = (box[1] + box[3]) // 2
+        img_h, img_w = image.shape[:2]
+
+        w = box_w + config.hd_strategy_crop_margin * 2
+        h = box_h + config.hd_strategy_crop_margin * 2
+
+        l = max(cx - w // 2, 0)
+        t = max(cy - h // 2, 0)
+        r = min(cx + w // 2, img_w)
+        b = min(cy + h // 2, img_h)
+
+        crop_img = image[t:b, l:r, :]
+        crop_mask = mask[t:b, l:r]
+
+        logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
+
+        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py
new file mode 100644
index 0000000..20ea2f8
--- /dev/null
+++ b/lama_cleaner/model/lama.py
@@ -0,0 +1,64 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+from loguru import logger
+
+from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img
+from lama_cleaner.model.base import InpaintModel
+from lama_cleaner.schema import Config
+
+LAMA_MODEL_URL = os.environ.get(
+    "LAMA_MODEL_URL",
+    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+)
+
+
+class LaMa(InpaintModel):
+    pad_mod = 8
+
+    def __init__(self, device):
+        """
+
+        Args:
+            device:
+        """
+        super().__init__(device)
+        self.device = device
+
+    def init_model(self, device):
+        if os.environ.get("LAMA_MODEL"):
+            model_path = os.environ.get("LAMA_MODEL")
+            if not os.path.exists(model_path):
+                raise FileNotFoundError(
+                    f"lama torchscript model not found: {model_path}"
+                )
+        else:
+            model_path = download_model(LAMA_MODEL_URL)
+
+        logger.info(f"Load LaMa model from: {model_path}")
+        model = torch.jit.load(model_path, map_location="cpu")
+        model = model.to(device)
+        model.eval()
+        self.model = model
+
+    def forward(self, image, mask, config: Config):
+        """Input image and output image have same size
+        image: [H, W, C] RGB
+        mask: [H, W]
+        return: BGR IMAGE
+        """
+        image = norm_img(image)
+        mask = norm_img(mask)
+
+        mask = (mask > 0) * 1
+        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+
+        inpainted_image = self.model(image, mask)
+
+        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
+        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
+        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
+        return cur_res
diff --git a/lama_cleaner/ldm/__init__.py b/lama_cleaner/model/ldm.py
similarity index 76%
rename from lama_cleaner/ldm/__init__.py
rename to lama_cleaner/model/ldm.py
index fa077ec..0b433b9 100644
--- a/lama_cleaner/ldm/__init__.py
+++ b/lama_cleaner/model/ldm.py
@@ -2,13 +2,16 @@ import os
 
 import numpy as np
 import torch
+from loguru import logger
+
+from lama_cleaner.model.base import InpaintModel
+from lama_cleaner.schema import Config
 
 torch.manual_seed(42)
 import torch.nn as nn
 from tqdm import tqdm
-import cv2
-from lama_cleaner.helper import pad_img_to_modulo, download_model
-from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
+from lama_cleaner.helper import download_model, norm_img
+from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
     timestep_embedding
 
 LDM_ENCODE_MODEL_URL = os.environ.get(
@@ -217,7 +220,7 @@ class DDIMSampler(object):
         time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
 
-        print(f"Running DDIM Sampling with {total_steps} timesteps")
+        logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
 
         iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
 
@@ -268,97 +271,39 @@ def load_jit_model(url, device):
     return model
 
 
-class LDM:
-    def __init__(self, device, steps=50):
+class LDM(InpaintModel):
+    pad_mod = 32
+
+    def __init__(self, device):
+        super().__init__(device)
         self.device = device
 
+    def init_model(self, device):
         self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
         self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
         self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
 
         model = LatentDiffusion(self.diffusion_model, device)
         self.sampler = DDIMSampler(model)
-        self.steps = steps
 
-    def _norm(self, tensor):
-        return tensor * 2.0 - 1.0
-
-    @torch.no_grad()
-    def __call__(self, image, mask):
+    def forward(self, image, mask, config: Config):
         """
-        image: [C, H, W] RGB
-        mask: [1, H, W]
+        image: [H, W, C] RGB
+        mask: [H, W, 1]
         return: BGR IMAGE
         """
         # image [1,3,512,512] float32
         # mask: [1,1,512,512] float32
         # masked_image: [1,3,512,512] float32
-        origin_height, origin_width = image.shape[1:]
-        image = pad_img_to_modulo(image, mod=32)
-        mask = pad_img_to_modulo(mask, mod=32)
-        padded_height, padded_width = image.shape[1:]
+        steps = config.ldm_steps
+        image = norm_img(image)
+        mask = norm_img(mask)
+
         mask[mask < 0.5] = 0
         mask[mask >= 0.5] = 1
 
-        # crop 512 x 512
-        if padded_width <= 512 or padded_height <= 512:
-            np_img = self._forward(image, mask, self.device)
-        else:
-            print("Try to zoom in")
-            # zoom in
-            # x,y,w,h
-            # box = self.box_from_bitmap(mask)
-            box = self.find_main_content(mask)
-            if box is None:
-                print("No bbox found")
-                np_img = self._forward(image, mask, self.device)
-            else:
-                print(f"box: {box}")
-                box_x, box_y, box_w, box_h = box
-                cx = box_x + box_w // 2
-                cy = box_y + box_h // 2
-
-                # w = max(512, box_w)
-                # h = max(512, box_h)
-                w = box_w + 512
-                h = box_h + 512
-
-                left = max(cx - w // 2, 0)
-                top = max(cy - h // 2, 0)
-                right = min(cx + w // 2, origin_width)
-                bottom = min(cy + h // 2, origin_height)
-
-                x = left
-                y = top
-                w = right - left
-                h = bottom - top
-
-                crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
-                crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
-
-                print(f"Apply zoom in size width x height: {crop_img.shape}")
-
-                crop_img_height, crop_img_width = crop_img.shape[1:]
-
-                crop_img = pad_img_to_modulo(crop_img, mod=32)
-                crop_mask = pad_img_to_modulo(crop_mask, mod=32)
-                # RGB
-                np_img = self._forward(crop_img, crop_mask, self.device)
-
-                image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
-                image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
-                np_img = image
-                # BGR to RGB
-                # np_img = image[:, :, ::-1]
-
-        np_img = np_img[0:origin_height, 0:origin_width, :]
-        np_img = np_img[:, :, ::-1]
-
-        return np_img
-
-    def _forward(self, image, mask, device):
-        image = torch.from_numpy(image).unsqueeze(0).to(device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
 
         masked_image = (1 - mask) * image
         image = self._norm(image)
@@ -371,47 +316,20 @@ class LDM:
         c = torch.cat((c, cc), dim=1)  # 1,4,128,128
 
         shape = (c.shape[1] - 1,) + c.shape[2:]
-        samples_ddim = self.sampler.sample(steps=self.steps,
+        samples_ddim = self.sampler.sample(steps=steps,
                                            conditioning=c,
                                            batch_size=c.shape[0],
                                            shape=shape)
         x_samples_ddim = self.cond_stage_model_decode(samples_ddim)  # samples_ddim: 1, 3, 128, 128 float32
 
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-        mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
-        predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
+        inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
-        inpainted = (1 - mask) * image + mask * predicted_image
-        inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
-        np_img = inpainted.astype(np.uint8)
-        return np_img
+        # inpainted = (1 - mask) * image + mask * predicted_image
+        inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
+        inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
+        return inpainted_image
 
-    def find_main_content(self, bitmap: np.ndarray):
-        th2 = bitmap[0].astype(np.uint8)
-        row_sum = th2.sum(1)
-        col_sum = th2.sum(0)
-        xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
-        xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
-        ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
-        ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
-
-        left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
-        return left, top, right - left, bottom - top
-
-    def box_from_bitmap(self, bitmap):
-        """
-        bitmap: single map with shape (NUM_CLASSES, H, W),
-        whose values are binarized as {0, 1}
-        """
-        contours, _ = cv2.findContours(
-            (bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
-        )
-
-        contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
-        num_contours = len(contours)
-        print(f"contours size: {num_contours}")
-        if num_contours != 1:
-            return None
-
-        # x,y,w,h
-        return cv2.boundingRect(contours[0])
+    def _norm(self, tensor):
+        return tensor * 2.0 - 1.0
diff --git a/lama_cleaner/ldm/utils.py b/lama_cleaner/model/utils.py
similarity index 100%
rename from lama_cleaner/ldm/utils.py
rename to lama_cleaner/model/utils.py
diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py
new file mode 100644
index 0000000..105561e
--- /dev/null
+++ b/lama_cleaner/model_manager.py
@@ -0,0 +1,34 @@
+from lama_cleaner.model.lama import LaMa
+from lama_cleaner.model.ldm import LDM
+from lama_cleaner.schema import Config
+
+
+class ModelManager:
+    LAMA = 'lama'
+    LDM = 'ldm'
+
+    def __init__(self, name: str, device):
+        self.name = name
+        self.device = device
+        self.model = self.init_model(name, device)
+
+    def init_model(self, name: str, device):
+        if name == self.LAMA:
+            model = LaMa(device)
+        elif name == self.LDM:
+            model = LDM(device)
+        else:
+            raise NotImplementedError(f"Not supported model: {name}")
+        return model
+
+    def __call__(self, image, mask, config: Config):
+        return self.model(image, mask, config)
+
+    def switch(self, new_name: str):
+        if new_name == self.name:
+            return
+        try:
+            self.model = self.init_model(new_name, self.device)
+            self.name = new_name
+        except NotImplementedError as e:
+            raise e
diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py
new file mode 100644
index 0000000..0d6fad8
--- /dev/null
+++ b/lama_cleaner/schema.py
@@ -0,0 +1,17 @@
+from enum import Enum
+
+from pydantic import BaseModel
+
+
+class HDStrategy(str, Enum):
+    ORIGINAL = 'Original'
+    RESIZE = 'Resize'
+    CROP = 'Crop'
+
+
+class Config(BaseModel):
+    ldm_steps: int
+    hd_strategy: str
+    hd_strategy_crop_margin: int
+    hd_strategy_crop_trigger_size: int
+    hd_strategy_resize_limit: int
diff --git a/lama_cleaner/tests/__init__.py b/lama_cleaner/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lama_cleaner/tests/image.png b/lama_cleaner/tests/image.png
new file mode 100644
index 0000000..74c7a7b
Binary files /dev/null and b/lama_cleaner/tests/image.png differ
diff --git a/lama_cleaner/tests/lama_crop_result.png b/lama_cleaner/tests/lama_crop_result.png
new file mode 100644
index 0000000..ce37e3c
Binary files /dev/null and b/lama_cleaner/tests/lama_crop_result.png differ
diff --git a/lama_cleaner/tests/lama_original_result.png b/lama_cleaner/tests/lama_original_result.png
new file mode 100644
index 0000000..d91e7b3
Binary files /dev/null and b/lama_cleaner/tests/lama_original_result.png differ
diff --git a/lama_cleaner/tests/lama_resize_result.png b/lama_cleaner/tests/lama_resize_result.png
new file mode 100644
index 0000000..70a940c
Binary files /dev/null and b/lama_cleaner/tests/lama_resize_result.png differ
diff --git a/lama_cleaner/tests/ldm_crop_result.png b/lama_cleaner/tests/ldm_crop_result.png
new file mode 100644
index 0000000..4fdda35
Binary files /dev/null and b/lama_cleaner/tests/ldm_crop_result.png differ
diff --git a/lama_cleaner/tests/ldm_original_result.png b/lama_cleaner/tests/ldm_original_result.png
new file mode 100644
index 0000000..2013757
Binary files /dev/null and b/lama_cleaner/tests/ldm_original_result.png differ
diff --git a/lama_cleaner/tests/ldm_resize_result.png b/lama_cleaner/tests/ldm_resize_result.png
new file mode 100644
index 0000000..c893f1c
Binary files /dev/null and b/lama_cleaner/tests/ldm_resize_result.png differ
diff --git a/lama_cleaner/tests/mask.jpg b/lama_cleaner/tests/mask.jpg
deleted file mode 100644
index a2aec11..0000000
Binary files a/lama_cleaner/tests/mask.jpg and /dev/null differ
diff --git a/lama_cleaner/tests/mask.png b/lama_cleaner/tests/mask.png
new file mode 100644
index 0000000..29cf20b
Binary files /dev/null and b/lama_cleaner/tests/mask.png differ
diff --git a/lama_cleaner/tests/test_boxes_from_mask.py b/lama_cleaner/tests/test_boxes_from_mask.py
deleted file mode 100644
index 3faa4c6..0000000
--- a/lama_cleaner/tests/test_boxes_from_mask.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import cv2
-import numpy as np
-
-from lama_cleaner.helper import boxes_from_mask
-
-
-def test_boxes_from_mask():
-    mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
-    mask = mask[:, :, np.newaxis]
-    mask = (mask / 255).transpose(2, 0, 1)
-    boxes = boxes_from_mask(mask)
-    print(boxes)
-
-
-test_boxes_from_mask()
diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py
new file mode 100644
index 0000000..0b2b3c3
--- /dev/null
+++ b/lama_cleaner/tests/test_model.py
@@ -0,0 +1,55 @@
+import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pytest
+
+from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.schema import Config, HDStrategy
+
+current_dir = Path(__file__).parent.absolute().resolve()
+
+
+def get_data():
+    img = cv2.imread(str(current_dir / 'image.png'))
+    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
+    mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
+    return img, mask
+
+
+def get_config(strategy):
+    return Config(
+        ldm_steps=1,
+        hd_strategy=strategy,
+        hd_strategy_crop_margin=32,
+        hd_strategy_crop_trigger_size=200,
+        hd_strategy_resize_limit=200,
+    )
+
+
+def assert_equal(model, config, gt_name):
+    img, mask = get_data()
+    res = model(img, mask, config)
+    # cv2.imwrite(gt_name, res,
+    #             [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])
+
+    """
+    Note that JPEG is lossy compression, so even if it is the highest quality 100,
+    when the saved image is reloaded, a difference occurs with the original pixel value.
+    If you want to save the original image as it is, save it as PNG or BMP.
+    """
+    gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
+    assert np.array_equal(res, gt)
+
+
+@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
+def test_lama(strategy):
+    model = ModelManager(name='lama', device='cpu')
+    assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')
+
+
+@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
+def test_ldm(strategy):
+    model = ModelManager(name='ldm', device='cpu')
+    assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')
diff --git a/main.py b/main.py
index 2cff582..0c917d6 100644
--- a/main.py
+++ b/main.py
@@ -2,19 +2,21 @@ import argparse
 import io
+import logging
 import multiprocessing
 import os
 import time
 import imghdr
+from pathlib import Path
 from typing import Union
 
 import cv2
 import torch
 import numpy as np
-from lama_cleaner.lama import LaMa
-from lama_cleaner.ldm import LDM
+from loguru import logger
 
-from flaskwebgui import FlaskUI
+from lama_cleaner.model_manager import ModelManager
+from lama_cleaner.schema import Config
 
 try:
     torch._C._jit_override_can_fuse_on_cpu(False)
@@ -29,7 +31,6 @@
 from flask_cors import CORS
 
 from lama_cleaner.helper import (
     load_img,
-    norm_img,
     numpy_to_bytes,
     resize_max_size,
 )
@@ -46,11 +47,19 @@
 if os.environ.get("CACHE_DIR"):
 
 BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
 
+
+class InterceptHandler(logging.Handler):
+    def emit(self, record):
+        logger_opt = logger.opt(depth=6, exception=record.exc_info)
+        logger_opt.log(record.levelno, record.getMessage())
+
+
 app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
 app.config["JSON_AS_ASCII"] = False
-CORS(app)
+app.logger.addHandler(InterceptHandler())
+CORS(app, expose_headers=["Content-Disposition"])
 
-model = None
+model: ModelManager = None
 device = None
 input_image_path: str = None
@@ -72,24 +81,31 @@ def process():
     original_shape = image.shape
     interpolation = cv2.INTER_CUBIC
 
-    size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
+    form = request.form
+    size_limit: Union[int, str] = form.get("sizeLimit", "1080")
     if size_limit == "Original":
         size_limit = max(image.shape)
     else:
         size_limit = int(size_limit)
 
-    print(f"Origin image shape: {original_shape}")
+    config = Config(
+        ldm_steps=form['ldmSteps'],
+        hd_strategy=form['hdStrategy'],
+        hd_strategy_crop_margin=form['hdStrategyCropMargin'],
+        hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
+        hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
+    )
+
+    logger.info(f"Origin image shape: {original_shape}")
     image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
-    print(f"Resized image shape: {image.shape}")
-    image = norm_img(image)
+    logger.info(f"Resized image shape: {image.shape}")
 
     mask, _ = load_img(input["mask"].read(), gray=True)
     mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
-    mask = norm_img(mask)
 
     start = time.time()
-    res_np_img = model(image, mask)
-    print(f"process time: {(time.time() - start) * 1000}ms")
+    res_np_img = model(image, mask, config)
+    logger.info(f"process time: {(time.time() - start) * 1000}ms")
 
     torch.cuda.empty_cache()
 
@@ -109,6 +125,19 @@ def process():
     )
 
 
+@app.route("/switch_model", methods=["POST"])
+def switch_model():
+    new_name = request.form.get("name")
+    if new_name == model.name:
+        return "Same model", 200
+
+    try:
+        model.switch(new_name)
+    except NotImplementedError:
+        return f"{new_name} not implemented", 403
+    return f"ok, switch to {new_name}", 200
+
+
 @app.route("/")
 def index():
     return send_file(os.path.join(BUILD_DIR, "index.html"))
@@ -120,7 +149,9 @@ def set_input_photo():
         with open(input_image_path, "rb") as f:
             image_in_bytes = f.read()
         return send_file(
-            io.BytesIO(image_in_bytes),
+            input_image_path,
+            as_attachment=True,
+            download_name=Path(input_image_path).name,
             mimetype=f"image/{get_image_ext(image_in_bytes)}",
         )
     else:
@@ -135,29 +166,6 @@ def get_args_parser():
     parser.add_argument("--host", default="127.0.0.1")
     parser.add_argument("--port", default=8080, type=int)
     parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
-    parser.add_argument(
-        "--crop-trigger-size",
-        default=[2042, 2042],
-        nargs=2,
-        type=int,
-        help="If image size large then crop-trigger-size, "
-        "crop each area from original image to do inference."
-        "Mainly for performance and memory reasons"
-        "Only for lama",
-    )
-    parser.add_argument(
-        "--crop-margin",
-        type=int,
-        default=256,
-        help="Margin around bounding box of painted stroke when crop mode triggered",
-    )
-    parser.add_argument(
-        "--ldm-steps",
-        default=50,
-        type=int,
-        help="Steps for DDIM sampling process."
-        "The larger the value, the better the result, but it will be more time-consuming",
-    )
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
     parser.add_argument(
@@ -188,19 +196,11 @@ def main():
     device = torch.device(args.device)
     input_image_path = args.input
 
-    if args.model == "lama":
-        model = LaMa(
-            crop_trigger_size=args.crop_trigger_size,
-            crop_margin=args.crop_margin,
-            device=device,
-        )
-    elif args.model == "ldm":
-        model = LDM(device, steps=args.ldm_steps)
-    else:
-        raise NotImplementedError(f"Not supported model: {args.model}")
+    model = ModelManager(name=args.model, device=device)
 
     if args.gui:
         app_width, app_height = args.gui_size
+        from flaskwebgui import FlaskUI
         ui = FlaskUI(app, width=app_width, height=app_height)
         ui.run()
     else:
diff --git a/requirements.txt b/requirements.txt
index 5bb3a13..c0e911e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,9 @@
 torch>=1.8.2
 opencv-python
 flask_cors
-flask
+flask==2.1.1
 flaskwebgui
 tqdm
+pydantic
+loguru
+pytest