Merge pull request #51 from Sanster/plms

add plms sampler
Authored by Qing on 2022-06-12 22:20:35 +08:00, committed by GitHub
commit d43d0694c2
20 changed files with 536 additions and 224 deletions

View File

@@ -36,10 +36,40 @@ Available commands:
| --port | Port for flask web server | 8080 |
| --debug | Enable debug mode for flask web server | |
+## Settings
+You can change the configuration of the inpainting process in the settings interface of the web page.
+<img src="./assets/settings.png" width="400px">
+### Inpainting Model
+Select the inpainting model to use, and set the options corresponding to that model.
+The LaMa model has no options that can be changed at runtime.
+The LDM model has two options that control the quality of the final result:
+1. Steps: more steps usually give a better result, but inpainting takes longer.
+2. Sampler: ddim or [plms](https://arxiv.org/abs/2202.09778). In general, plms gets better results with fewer steps.
+### High Resolution Strategy
+There are three strategies for handling high-resolution images.
+- **Original**: Use the original resolution of the image. Suitable for images below 2K.
+- **Resize**: Resize the longer side of the image to a specific size (keeping the aspect ratio), then inpaint the resized image.
+The inpainting result is pasted back onto the original image, so the rest of the image does not lose quality.
+- **Crop**: Crop the masked area out of the original image, inpaint it, and paste the result back.
+This is mainly for performance and memory reasons on high-resolution images. This strategy may give better results for the LDM model.
## Model Comparison
-Diffusion model (LDM) is **MUCH** slower than GANs (LaMa): a 1080x720 image takes 8s on a 3090. But it's possible to get a better result, see the example below:
+| Model | Pros | Cons |
+|-------|------|------|
+| LaMa | - Performs well on high-resolution images (~2K)<br/>- Faster than the diffusion model | |
+| LDM | - Can produce better, more detailed results, see the example below<br/>- Steps let you balance time against quality | - Slower than the GAN model<br/>- Needs more GPU memory<br/>- Not good for high-resolution images |
| Original Image | LaMa | LDM |
| --- | --- | --- |
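With the `ldmSampler` field now threaded from the web UI through the form data to the backend, a minimal client call against a locally running server might look like the sketch below. This is not part of the PR: the `/inpaint` endpoint path and port are assumptions (the route is not shown in this diff; the port matches the `--port` default above), and any extra form fields the server requires beyond those visible in this diff are elided.

```python
# Sketch: POST an image/mask pair to lama-cleaner, selecting the new plms sampler.
# Endpoint path and port are assumptions; the field names match app.py in this PR.
import requests

with open("image.png", "rb") as image, open("mask.png", "rb") as mask:
    resp = requests.post(
        "http://127.0.0.1:8080/inpaint",
        files={"image": image, "mask": mask},
        data={
            "ldmSteps": "50",
            "ldmSampler": "plms",  # or "ddim"
            "hdStrategy": "Original",
            "hdStrategyCropMargin": "128",
            "hdStrategyCropTrigerSize": "2048",
        },
    )
resp.raise_for_status()
with open("result.png", "wb") as out:
    out.write(resp.content)  # the server returns the inpainted image as a blob
```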

BIN assets/settings.png (new file, 80 KiB; binary not shown)

View File

@ -1,7 +1,7 @@
{ {
"files": { "files": {
"main.css": "/static/css/main.d1028b29.chunk.css", "main.css": "/static/css/main.bdd774f1.chunk.css",
"main.js": "/static/js/main.2711860b.chunk.js", "main.js": "/static/js/main.9f4ed6dc.chunk.js",
"runtime-main.js": "/static/js/runtime-main.5e86ac81.js", "runtime-main.js": "/static/js/runtime-main.5e86ac81.js",
"static/js/2.1b1d3019.chunk.js": "/static/js/2.1b1d3019.chunk.js", "static/js/2.1b1d3019.chunk.js": "/static/js/2.1b1d3019.chunk.js",
"index.html": "/index.html", "index.html": "/index.html",
@ -11,7 +11,7 @@
"entrypoints": [ "entrypoints": [
"static/js/runtime-main.5e86ac81.js", "static/js/runtime-main.5e86ac81.js",
"static/js/2.1b1d3019.chunk.js", "static/js/2.1b1d3019.chunk.js",
"static/css/main.d1028b29.chunk.css", "static/css/main.bdd774f1.chunk.css",
"static/js/main.2711860b.chunk.js" "static/js/main.9f4ed6dc.chunk.js"
] ]
} }

View File

@@ -1 +1 @@
-<!doctype html>... (minified index.html, referencing main.d1028b29.chunk.css and main.2711860b.chunk.js)
+<!doctype html>... (minified index.html, identical except for the hashed asset names: main.bdd774f1.chunk.css and main.9f4ed6dc.chunk.js)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -16,6 +16,7 @@ export default async function inpaint(
  fd.append('mask', mask)
  fd.append('ldmSteps', settings.ldmSteps.toString())
+ fd.append('ldmSampler', settings.ldmSampler.toString())
  fd.append('hdStrategy', settings.hdStrategy)
  fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
  fd.append(
@@ -34,7 +35,11 @@ export default async function inpaint(
    method: 'POST',
    body: fd,
  }).then(async r => {
-   return r.blob()
+   console.log(r)
+   if (r.ok) {
+     return r.blob()
+   }
+   throw new Error('Something went wrong on server side.')
  })
  return URL.createObjectURL(res)

View File

@@ -15,7 +15,7 @@ import {
  TransformComponent,
  TransformWrapper,
} from 'react-zoom-pan-pinch'
-import { useRecoilValue } from 'recoil'
+import { useRecoilState, useRecoilValue } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint from '../../adapters/inpainting'
import Button from '../shared/Button'
@@ -28,7 +28,7 @@ import {
  loadImage,
  useImage,
} from '../../utils'
-import { settingState } from '../../store/Atoms'
+import { settingState, toastState } from '../../store/Atoms'
const TOOLBAR_SIZE = 200
const BRUSH_COLOR = '#ffcc00bb'
@@ -73,6 +73,7 @@ function mouseXY(ev: SyntheticEvent) {
export default function Editor(props: EditorProps) {
  const { file } = props
  const settings = useRecoilValue(settingState)
+ const [toastVal, setToastState] = useRecoilState(toastState)
  const [brushSize, setBrushSize] = useState(40)
  const [original, isOriginalLoaded] = useImage(file)
  const [renders, setRenders] = useState<HTMLImageElement[]>([])
@@ -144,12 +145,11 @@ export default function Editor(props: EditorProps) {
  }
  const newLineGroups = [...lineGroups, curLineGroup]
- setLineGroups(newLineGroups)
  setCurLineGroup([])
  setIsDraging(false)
  setIsInpaintingLoading(true)
  drawAllLinesOnMask(newLineGroups)
  try {
    const res = await inpaint(
      file,
@@ -165,9 +165,16 @@ export default function Editor(props: EditorProps) {
    const newRenders = [...renders, newRender]
    setRenders(newRenders)
    draw(newRender, [])
+   // Only append the new LineGroup after inpainting succeeds
+   setLineGroups(newLineGroups)
  } catch (e: any) {
-   // eslint-disable-next-line
-   alert(e.message ? e.message : e.toString())
+   setToastState({
+     open: true,
+     desc: e.message ? e.message : e.toString(),
+     state: 'error',
+     duration: 2000,
+   })
+   drawOnCurrentRender([])
  }
  setIsInpaintingLoading(false)
}

View File

@@ -11,6 +11,11 @@ export enum HDStrategy {
  CROP = 'Crop',
}
+export enum LDMSampler {
+  ddim = 'ddim',
+  plms = 'plms',
+}
function HDSettingBlock() {
  const [setting, setSettingState] = useRecoilState(settingState)

View File

@@ -2,6 +2,7 @@ import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Selector from '../shared/Selector'
+import { LDMSampler } from './HDSettingBlock'
import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'
@@ -19,6 +20,12 @@ function ModelSettingBlock() {
    })
  }
+ const onLDMSamplerChange = (value: LDMSampler) => {
+   setSettingState(old => {
+     return { ...old, ldmSampler: value }
+   })
+ }
  const renderModelDesc = (
    name: string,
    paperUrl: string,
@@ -65,6 +72,19 @@ function ModelSettingBlock() {
      })
    }}
  />
+ <SettingBlock
+   className="sub-setting-block"
+   title="Sampler"
+   input={
+     <Selector
+       width={80}
+       value={setting.ldmSampler as string}
+       options={Object.values(LDMSampler)}
+       onChange={val => onLDMSamplerChange(val as LDMSampler)}
+     />
+   }
+ />
</div>
)
}

View File

@@ -56,7 +56,7 @@ export default function ShortcutsModal() {
  />
  <ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
  <ShortCut content="Pan" keys={['Space & Drag']} />
- <ShortCut content="View Original Image" keys={['Hold Tag']} />
+ <ShortCut content="View Original Image" keys={['Hold Tab']} />
  <ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
  <ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
  <ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />

View File

@@ -2,8 +2,9 @@
  all: unset;
  flex: 1 0 auto;
  border-radius: 0.5rem;
- padding: 0.4rem 0.8rem;
+ padding: 0 0.8rem;
  outline: 1px solid var(--border-color);
+ height: 36px;
  &:focus-visible {
    outline: 1px solid var(--yellow-accent);

View File

@@ -1,5 +1,5 @@
import { atom } from 'recoil'
-import { HDStrategy } from '../components/Settings/HDSettingBlock'
+import { HDStrategy, LDMSampler } from '../components/Settings/HDSettingBlock'
import { AIModel } from '../components/Settings/ModelSettingBlock'
import { ToastState } from '../components/shared/Toast'
@@ -43,6 +43,7 @@ export interface Settings {
  // For LDM
  ldmSteps: number
+ ldmSampler: LDMSampler
}
export const settingStateDefault = {
@@ -50,6 +51,7 @@ export const settingStateDefault = {
  runInpaintingManually: false,
  model: AIModel.LAMA,
  ldmSteps: 50,
+ ldmSampler: LDMSampler.plms,
  hdStrategy: HDStrategy.RESIZE,
  hdStrategyResizeLimit: 2048,
  hdStrategyCropTrigerSize: 2048,

View File

@@ -0,0 +1,193 @@
import torch
import numpy as np
from tqdm import tqdm
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
from loguru import logger


class DDIMSampler(object):
    def __init__(self, model, schedule="linear"):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            # array([1])
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(self, steps, conditioning, batch_size, shape):
        self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        # samples: 1,3,128,128
        return self.ddim_sampling(
            conditioning,
            size,
            quantize_denoised=False,
            ddim_use_original_steps=False,
            noise_dropout=0,
            temperature=1.0,
        )

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        ddim_use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device, dtype=cond.dtype)
        timesteps = (
            self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        )

        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        logger.info(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
            )
            img, _ = outs

        return img

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:  # not used in this code path
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:  # not used in this code path
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
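For reference, `p_sample_ddim` above is a direct transcription of the DDIM update rule (Song et al., 2021). In the code's variable names:

$$
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\underbrace{\frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}}_{\texttt{pred\_x0}} \;+\; \underbrace{\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t, t)}_{\texttt{dir\_xt}} \;+\; \underbrace{\sigma_t z}_{\texttt{noise}}, \qquad z \sim \mathcal{N}(0, I)
$$

where the per-step noise scale comes from

$$
\sigma_t = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\left(1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}\right)},
$$

matching `sigmas_for_original_sampling_steps` in `make_schedule`. Since `sample` passes `ddim_eta=0`, every `sigma_t` is zero and the trajectory is deterministic given the initial noise.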

View File

@@ -5,17 +5,15 @@ import torch
from loguru import logger
from lama_cleaner.model.base import InpaintModel
-from lama_cleaner.schema import Config
+from lama_cleaner.model.ddim_sampler import DDIMSampler
+from lama_cleaner.model.plms_sampler import PLMSSampler
+from lama_cleaner.schema import Config, LDMSampler
torch.manual_seed(42)
import torch.nn as nn
-from tqdm import tqdm
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.utils import (
    make_beta_schedule,
-   make_ddim_timesteps,
-   make_ddim_sampling_parameters,
-   noise_like,
    timestep_embedding,
)
@@ -94,7 +92,7 @@ class DDPM(nn.Module):
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
@@ -120,7 +118,7 @@ class DDPM(nn.Module):
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
@@ -142,16 +140,16 @@ class DDPM(nn.Module):
        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                2
                * self.posterior_variance
                * to_torch(alphas)
                * (1 - self.alphas_cumprod)
            )
        elif self.parameterization == "x0":
            lvlb_weights = (
                0.5
                * np.sqrt(torch.Tensor(alphas_cumprod))
                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
            )
        else:
            raise NotImplementedError("mu not supported")
@@ -221,192 +219,6 @@ class LatentDiffusion(DDPM):
        return x_recon
-class DDIMSampler(object):
-    ... (entire class removed unchanged; it now lives in lama_cleaner/model/ddim_sampler.py, shown above)
def load_jit_model(url, device):
    model_path = download_model(url)
    logger.info(f"Load LDM model from: {model_path}")
@@ -432,8 +244,7 @@ class LDM(InpaintModel):
        self.cond_stage_model_decode = self.cond_stage_model_decode.half()
        self.cond_stage_model_encode = self.cond_stage_model_encode.half()
-       model = LatentDiffusion(self.diffusion_model, device)
-       self.sampler = DDIMSampler(model)
+       self.model = LatentDiffusion(self.diffusion_model, device)
    @staticmethod
    def is_downloaded() -> bool:
@@ -454,6 +265,13 @@ class LDM(InpaintModel):
        # image [1,3,512,512] float32
        # mask: [1,1,512,512] float32
        # masked_image: [1,3,512,512] float32
+       if config.ldm_sampler == LDMSampler.ddim:
+           sampler = DDIMSampler(self.model)
+       elif config.ldm_sampler == LDMSampler.plms:
+           sampler = PLMSSampler(self.model)
+       else:
+           raise ValueError()
        steps = config.ldm_steps
        image = norm_img(image)
        mask = norm_img(mask)
@@ -465,7 +283,6 @@ class LDM(InpaintModel):
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
        masked_image = (1 - mask) * image
-       image = self._norm(image)
        mask = self._norm(mask)
        masked_image = self._norm(masked_image)
@@ -476,7 +293,7 @@ class LDM(InpaintModel):
        c = torch.cat((c, cc), dim=1)  # 1,4,128,128
        shape = (c.shape[1] - 1,) + c.shape[2:]
-       samples_ddim = self.sampler.sample(
+       samples_ddim = sampler.sample(
            steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
        )
        torch.cuda.empty_cache()

View File

@@ -0,0 +1,225 @@
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
import torch
import numpy as np
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
from tqdm import tqdm


class PLMSSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               steps,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=False,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples = self.plms_sampling(conditioning, size,
                                     callback=callback,
                                     img_callback=img_callback,
                                     quantize_denoised=quantize_x0,
                                     mask=mask, x0=x0,
                                     ddim_use_original_steps=False,
                                     noise_dropout=noise_dropout,
                                     temperature=temperature,
                                     score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     x_T=x_T,
                                     log_every_t=log_every_t,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     )
        return samples

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

        return img

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
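The branching at the end of `p_sample_plms` is the pseudo linear multistep scheme from the PNDM paper: the very first step uses a pseudo improved-Euler average (at the cost of a second model call), and once `old_eps` holds enough history, the 2nd/3rd/4th-order Adams-Bashforth weightings take over. A purely illustrative sanity check that each weighting is consistent (the coefficients sum to the divisor, so `e_t_prime` is a combination with total weight 1):

```python
# Illustrative check of the Adams-Bashforth coefficient sets in p_sample_plms:
# each set of weights sums to its divisor, i.e. the combination has total weight 1.
coefficient_sets = {
    2: ([3, -1], 2),              # (3*e_t - old_eps[-1]) / 2
    3: ([23, -16, 5], 12),        # (23*e_t - 16*old_eps[-1] + 5*old_eps[-2]) / 12
    4: ([55, -59, 37, -9], 24),   # 4th-order set
}
for order, (weights, divisor) in coefficient_sets.items():
    assert sum(weights) == divisor, f"order-{order} weights are inconsistent"
print("all PLMS coefficient sets have total weight 1")
```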

View File

@@ -9,8 +9,14 @@ class HDStrategy(str, Enum):
    CROP = 'Crop'
+class LDMSampler(str, Enum):
+    ddim = 'ddim'
+    plms = 'plms'
class Config(BaseModel):
    ldm_steps: int
+   ldm_sampler: str
    hd_strategy: str
    hd_strategy_crop_margin: int
    hd_strategy_crop_trigger_size: int
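Because `LDMSampler` subclasses both `str` and `Enum`, the raw sampler string from the request form compares equal to the enum members; that is why `Config.ldm_sampler` can stay typed as plain `str` while `ldm.py` compares it against `LDMSampler.ddim`. A self-contained sketch of the behavior:

```python
from enum import Enum

class LDMSampler(str, Enum):
    ddim = 'ddim'
    plms = 'plms'

# A plain form value compares equal to the corresponding member...
assert 'plms' == LDMSampler.plms
# ...and can also be parsed into a member explicitly.
assert LDMSampler('ddim') is LDMSampler.ddim
```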

View File

@@ -93,6 +93,7 @@ def process():
    config = Config(
        ldm_steps=form["ldmSteps"],
+       ldm_sampler=form["ldmSampler"],
        hd_strategy=form["hdStrategy"],
        hd_strategy_crop_margin=form["hdStrategyCropMargin"],
        hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
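Note that Flask form values arrive as strings; pydantic coerces them into the declared field types, which is why `form["ldmSteps"]` can be passed straight into `ldm_steps: int`. A minimal illustration (assuming pydantic's default, non-strict validation):

```python
from pydantic import BaseModel

class MiniConfig(BaseModel):
    ldm_steps: int
    ldm_sampler: str

cfg = MiniConfig(ldm_steps="50", ldm_sampler="plms")  # "50" is coerced to int
assert cfg.ldm_steps == 50
```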

View File

@@ -21,7 +21,7 @@ def load_requirements():
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
setuptools.setup(
    name="lama-cleaner",
-   version="0.12.0",
+   version="0.13.0",
    author="PanicByte",
    author_email="cwq1913@gmail.com",
    description="Image inpainting tool powered by SOTA AI Model",