add MAT model

This commit is contained in:
Qing 2022-08-22 23:24:02 +08:00
parent a5e840765e
commit 6d2b24ed6b
8 changed files with 2132 additions and 9 deletions

View File

@ -26,6 +26,7 @@
1. [LaMa](https://github.com/saic-mdal/lama) 1. [LaMa](https://github.com/saic-mdal/lama)
1. [LDM](https://github.com/CompVis/latent-diffusion) 1. [LDM](https://github.com/CompVis/latent-diffusion)
1. [ZITS](https://github.com/DQiaole/ZITS_inpainting) 1. [ZITS](https://github.com/DQiaole/ZITS_inpainting)
1. [MAT](https://github.com/fenglinglwb/MAT)
- Support CPU & GPU - Support CPU & GPU
- Various high-resolution image processing [strategy](#high-resolution-strategy) - Various high-resolution image processing [strategy](#high-resolution-strategy)
- Run as a desktop APP - Run as a desktop APP
@ -36,7 +37,7 @@
| ---------------------- | --------------------------------------------- | --------------------------------------------------- | | ---------------------- | --------------------------------------------- | --------------------------------------------------- |
| Remove unwanted things | ![unwant_object2](./assets/unwant_object.jpg) | ![unwant_object2](./assets/unwant_object_clean.jpg) | | Remove unwanted things | ![unwant_object2](./assets/unwant_object.jpg) | ![unwant_object2](./assets/unwant_object_clean.jpg) |
| Remove unwanted person | ![unwant_person](./assets/unwant_person.jpg) | ![unwant_person](./assets/unwant_person_clean.jpg) | | Remove unwanted person | ![unwant_person](./assets/unwant_person.jpg) | ![unwant_person](./assets/unwant_person_clean.jpg) |
| Remove Text | ![text](./assets/unwant_text.jpg) | ![watermark_clean](./assets/unwant_text_clean.jpg) | | Remove Text | ![text](./assets/unwant_text.jpg) | ![text](./assets/unwant_text_clean.jpg) |
| Remove watermark | ![watermark](./assets/watermark.jpg) | ![watermark_clean](./assets/watermark_cleanup.jpg) | | Remove watermark | ![watermark](./assets/watermark.jpg) | ![watermark_clean](./assets/watermark_cleanup.jpg) |
| Fix old photo | ![oldphoto](./assets/old_photo.jpg) | ![oldphoto_clean](./assets/old_photo_clean.jpg) | | Fix old photo | ![oldphoto](./assets/old_photo.jpg) | ![oldphoto_clean](./assets/old_photo_clean.jpg) |
@ -69,6 +70,7 @@ Available arguments:
| LaMa | :+1: Generalizes well on high resolutions(~2k)<br/> | | | LaMa | :+1: Generalizes well on high resolutions(~2k)<br/> | |
| LDM | :+1: Possibility to get better and more detailed results <br/> :+1: The balance of time and quality can be achieved by adjusting `steps` <br/> :neutral_face: Slower than GAN model<br/> :neutral_face: Need more GPU memory | `Steps`: You can get better result with large steps, but it will be more time-consuming <br/> `Sampler`: ddim or [plms](https://arxiv.org/abs/2202.09778). In general plms can get [better results](https://github.com/Sanster/lama-cleaner/releases/tag/0.13.0) with fewer steps | | LDM | :+1: Possibility to get better and more detailed results <br/> :+1: The balance of time and quality can be achieved by adjusting `steps` <br/> :neutral_face: Slower than GAN model<br/> :neutral_face: Need more GPU memory | `Steps`: You can get better result with large steps, but it will be more time-consuming <br/> `Sampler`: ddim or [plms](https://arxiv.org/abs/2202.09778). In general plms can get [better results](https://github.com/Sanster/lama-cleaner/releases/tag/0.13.0) with fewer steps |
| ZITS | :+1: Better holistic structures compared with previous methods <br/> :neutral_face: Wireframe module is **very** slow on CPU | `Wireframe`: Enable edge and line detect | | ZITS | :+1: Better holistic structures compared with previous methods <br/> :neutral_face: Wireframe module is **very** slow on CPU | `Wireframe`: Enable edge and line detect |
| MAT | TODO | |
### LaMa vs LDM ### LaMa vs LDM

View File

@ -131,6 +131,8 @@ function ModelSettingBlock() {
return renderLDMModelDesc() return renderLDMModelDesc()
case AIModel.ZITS: case AIModel.ZITS:
return renderZITSModelDesc() return renderZITSModelDesc()
case AIModel.MAT:
return undefined
default: default:
return <></> return <></>
} }
@ -156,6 +158,12 @@ function ModelSettingBlock() {
'https://arxiv.org/abs/2203.00867', 'https://arxiv.org/abs/2203.00867',
'https://github.com/DQiaole/ZITS_inpainting' 'https://github.com/DQiaole/ZITS_inpainting'
) )
case AIModel.MAT:
return renderModelDesc(
'Mask-Aware Transformer for Large Hole Image Inpainting',
'https://arxiv.org/pdf/2203.15270.pdf',
'https://github.com/fenglinglwb/MAT'
)
default: default:
return <></> return <></>
} }

View File

@ -7,6 +7,7 @@ export enum AIModel {
LAMA = 'lama', LAMA = 'lama',
LDM = 'ldm', LDM = 'ldm',
ZITS = 'zits', ZITS = 'zits',
MAT = 'mat',
} }
export const fileState = atom<File | undefined>({ export const fileState = atom<File | undefined>({
@ -80,6 +81,12 @@ const defaultHDSettings: ModelsHDSettings = {
hdStrategyCropTrigerSize: 1024, hdStrategyCropTrigerSize: 1024,
hdStrategyCropMargin: 128, hdStrategyCropMargin: 128,
}, },
[AIModel.MAT]: {
hdStrategy: HDStrategy.CROP,
hdStrategyResizeLimit: 1024,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
},
} }
export const settingStateDefault: Settings = { export const settingStateDefault: Settings = {

View File

@ -53,6 +53,26 @@ def load_jit_model(url_or_path, device):
return model return model
def load_model(model: torch.nn.Module, url_or_path, device):
    """Load a state dict into ``model``, move it to ``device`` and set eval mode.

    Args:
        model: An already-constructed module whose parameters will be filled.
        url_or_path: An existing local path to the weights; anything that
            does not exist on disk is passed to ``download_model`` to obtain
            a local path.
        device: Target device passed to ``model.to``.

    Returns:
        The same ``model`` instance, loaded, moved to ``device`` and in
        eval mode.

    Raises:
        SystemExit: With code ``-1`` when the weights cannot be loaded.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)

    try:
        # Deserialize on CPU first; the module is moved to `device` afterwards.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
        logger.info(f"Load model from: {model_path}")
    except Exception as e:
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit propagate untouched, and report the
        # underlying cause instead of discarding it.
        logger.error(f"{type(e).__name__}: {e}")
        logger.error(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        exit(-1)
    model.eval()
    return model
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
data = cv2.imencode( data = cv2.imencode(
f".{ext}", f".{ext}",

2064
lama_cleaner/model/mat.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,12 +1,14 @@
from lama_cleaner.model.lama import LaMa from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.mat import MAT
from lama_cleaner.model.zits import ZITS from lama_cleaner.model.zits import ZITS
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
models = { models = {
'lama': LaMa, 'lama': LaMa,
'ldm': LDM, 'ldm': LDM,
'zits': ZITS 'zits': ZITS,
'mat': MAT
} }

View File

@ -7,7 +7,7 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int) parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm", "zits"]) parser.add_argument("--model", default="lama", choices=["lama", "ldm", "zits", "mat"])
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"]) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument( parser.add_argument(

View File

@ -11,13 +11,19 @@ from lama_cleaner.schema import Config, HDStrategy, LDMSampler
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
def get_data(fx=1, fy=1.0):
    """Load the test image and mask stored next to this file, optionally rescaled.

    Args:
        fx: Horizontal scale factor forwarded to ``cv2.resize``.
        fy: Vertical scale factor forwarded to ``cv2.resize``.

    Returns:
        A ``(img, mask)`` tuple: the colour-converted image and the mask
        read as grayscale.
    """
    img = cv2.imread(str(current_dir / "image.png"))
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)

    # Resize when either axis is scaled; testing only `fx` would silently
    # ignore a request that changes `fy` alone.
    if fx != 1 or fy != 1:
        img = cv2.resize(img, None, fx=fx, fy=fy)
        mask = cv2.resize(mask, None, fx=fx, fy=fy)
    return img, mask
@ -34,8 +40,8 @@ def get_config(strategy, **kwargs):
return Config(**data) return Config(**data)
def assert_equal(model, config, gt_name, fx=1): def assert_equal(model, config, gt_name, fx=1, fy=1):
img, mask = get_data(fx=fx) img, mask = get_data(fx=fx, fy=fy)
res = model(img, mask, config) res = model(img, mask, config)
cv2.imwrite( cv2.imwrite(
str(current_dir / gt_name), str(current_dir / gt_name),
@ -111,6 +117,20 @@ def test_zits(strategy, zits_wireframe):
assert_equal( assert_equal(
model, model,
cfg, cfg,
f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_fx_{fx}_result.png", f"zits_{strategy.capitalize()}_wireframe_{zits_wireframe}_fx_{fx}_result.png",
fx=fx, fx=fx,
) )
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_mat(strategy):
    """Run the MAT model on CPU and hand the result to ``assert_equal``."""
    mat_model = ModelManager(name="mat", device="cpu")
    config = get_config(strategy)
    # Name of the result image, derived from the HD strategy under test.
    result_name = f"mat_{strategy.capitalize()}_result.png"
    assert_equal(mat_model, config, result_name)