commit
d43d0694c2
34
README.md
34
README.md
@ -36,10 +36,40 @@ Available commands:
|
|||||||
| --port | Port for flask web server | 8080 |
|
| --port | Port for flask web server | 8080 |
|
||||||
| --debug | Enable debug mode for flask web server | |
|
| --debug | Enable debug mode for flask web server | |
|
||||||
|
|
||||||
|
## Settings
|
||||||
|
|
||||||
|
You can change the configs of inpainting process in the settings interface of the web page.
|
||||||
|
|
||||||
|
<img src="./assets/settings.png" width="400px">
|
||||||
|
|
||||||
|
### Inpainting Model
|
||||||
|
|
||||||
|
Select the inpainting model to use, and set the configs corresponding to the model.
|
||||||
|
|
||||||
|
LaMa model has no configs that can be specified at runtime.
|
||||||
|
|
||||||
|
LDM model has two configs to control the quality of final result:
|
||||||
|
1. Steps: You can get better result with large steps, but it will be more time-consuming
|
||||||
|
2. Sampler: ddim or [plms](https://arxiv.org/abs/2202.09778). In general plms can get better results with fewer steps
|
||||||
|
|
||||||
|
|
||||||
|
### High Resolution Strategy
|
||||||
|
|
||||||
|
There are three strategies for handling high-resolution images.
|
||||||
|
|
||||||
|
- **Original**: Use the original resolution of the picture, suitable for picture size below 2K.
|
||||||
|
- **Resize**: Resize the longer side of the image to a specific size(keep ratio), then do inpainting on the resized image.
|
||||||
|
The inpainting result will be pasted back on the original image to make sure other part of image not loss quality.
|
||||||
|
- **Crop**: Crop masking area from the original image to do inpainting, and paste the result back.
|
||||||
|
Mainly for performance and memory reasons on high resolution image. This strategy may give better results for ldm model.
|
||||||
|
|
||||||
|
|
||||||
## Model Comparison
|
## Model Comparison
|
||||||
|
|
||||||
Diffusion model(ldm) is **MUCH MORE** slower than GANs(lama)(1080x720 image takes 8s on 3090), but it's possible to get better
|
| Model | Pron | Corn |
|
||||||
result, see below example:
|
|-------|-----------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------|
|
||||||
|
| LaMa | - Perform will on high resolution image(~2k)<br/> - Faster than diffusion model | |
|
||||||
|
| LDM | - It's possible to get better and more detail result, see example below<br/> - The balance of time and quality can be achieved by steps | - Slower than GAN model<br/> - Need more GPU memory<br/> - Not good for high resolution images |
|
||||||
|
|
||||||
| Original Image | LaMa | LDM |
|
| Original Image | LaMa | LDM |
|
||||||
| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
BIN
assets/settings.png
Normal file
BIN
assets/settings.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 80 KiB |
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"files": {
|
"files": {
|
||||||
"main.css": "/static/css/main.d1028b29.chunk.css",
|
"main.css": "/static/css/main.bdd774f1.chunk.css",
|
||||||
"main.js": "/static/js/main.2711860b.chunk.js",
|
"main.js": "/static/js/main.9f4ed6dc.chunk.js",
|
||||||
"runtime-main.js": "/static/js/runtime-main.5e86ac81.js",
|
"runtime-main.js": "/static/js/runtime-main.5e86ac81.js",
|
||||||
"static/js/2.1b1d3019.chunk.js": "/static/js/2.1b1d3019.chunk.js",
|
"static/js/2.1b1d3019.chunk.js": "/static/js/2.1b1d3019.chunk.js",
|
||||||
"index.html": "/index.html",
|
"index.html": "/index.html",
|
||||||
@ -11,7 +11,7 @@
|
|||||||
"entrypoints": [
|
"entrypoints": [
|
||||||
"static/js/runtime-main.5e86ac81.js",
|
"static/js/runtime-main.5e86ac81.js",
|
||||||
"static/js/2.1b1d3019.chunk.js",
|
"static/js/2.1b1d3019.chunk.js",
|
||||||
"static/css/main.d1028b29.chunk.css",
|
"static/css/main.bdd774f1.chunk.css",
|
||||||
"static/js/main.2711860b.chunk.js"
|
"static/js/main.9f4ed6dc.chunk.js"
|
||||||
]
|
]
|
||||||
}
|
}
|
@ -1 +1 @@
|
|||||||
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by LaMa</title><link href="/static/css/main.d1028b29.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.1b1d3019.chunk.js"></script><script src="/static/js/main.2711860b.chunk.js"></script></body></html>
|
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by LaMa</title><link href="/static/css/main.bdd774f1.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.1b1d3019.chunk.js"></script><script src="/static/js/main.9f4ed6dc.chunk.js"></script></body></html>
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
lama_cleaner/app/build/static/js/main.9f4ed6dc.chunk.js
Normal file
1
lama_cleaner/app/build/static/js/main.9f4ed6dc.chunk.js
Normal file
File diff suppressed because one or more lines are too long
@ -16,6 +16,7 @@ export default async function inpaint(
|
|||||||
fd.append('mask', mask)
|
fd.append('mask', mask)
|
||||||
|
|
||||||
fd.append('ldmSteps', settings.ldmSteps.toString())
|
fd.append('ldmSteps', settings.ldmSteps.toString())
|
||||||
|
fd.append('ldmSampler', settings.ldmSampler.toString())
|
||||||
fd.append('hdStrategy', settings.hdStrategy)
|
fd.append('hdStrategy', settings.hdStrategy)
|
||||||
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
|
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
|
||||||
fd.append(
|
fd.append(
|
||||||
@ -34,7 +35,11 @@ export default async function inpaint(
|
|||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: fd,
|
body: fd,
|
||||||
}).then(async r => {
|
}).then(async r => {
|
||||||
return r.blob()
|
console.log(r)
|
||||||
|
if (r.ok) {
|
||||||
|
return r.blob()
|
||||||
|
}
|
||||||
|
throw new Error('Something went wrong on server side.')
|
||||||
})
|
})
|
||||||
|
|
||||||
return URL.createObjectURL(res)
|
return URL.createObjectURL(res)
|
||||||
|
@ -15,7 +15,7 @@ import {
|
|||||||
TransformComponent,
|
TransformComponent,
|
||||||
TransformWrapper,
|
TransformWrapper,
|
||||||
} from 'react-zoom-pan-pinch'
|
} from 'react-zoom-pan-pinch'
|
||||||
import { useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
||||||
import inpaint from '../../adapters/inpainting'
|
import inpaint from '../../adapters/inpainting'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
@ -28,7 +28,7 @@ import {
|
|||||||
loadImage,
|
loadImage,
|
||||||
useImage,
|
useImage,
|
||||||
} from '../../utils'
|
} from '../../utils'
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState, toastState } from '../../store/Atoms'
|
||||||
|
|
||||||
const TOOLBAR_SIZE = 200
|
const TOOLBAR_SIZE = 200
|
||||||
const BRUSH_COLOR = '#ffcc00bb'
|
const BRUSH_COLOR = '#ffcc00bb'
|
||||||
@ -73,6 +73,7 @@ function mouseXY(ev: SyntheticEvent) {
|
|||||||
export default function Editor(props: EditorProps) {
|
export default function Editor(props: EditorProps) {
|
||||||
const { file } = props
|
const { file } = props
|
||||||
const settings = useRecoilValue(settingState)
|
const settings = useRecoilValue(settingState)
|
||||||
|
const [toastVal, setToastState] = useRecoilState(toastState)
|
||||||
const [brushSize, setBrushSize] = useState(40)
|
const [brushSize, setBrushSize] = useState(40)
|
||||||
const [original, isOriginalLoaded] = useImage(file)
|
const [original, isOriginalLoaded] = useImage(file)
|
||||||
const [renders, setRenders] = useState<HTMLImageElement[]>([])
|
const [renders, setRenders] = useState<HTMLImageElement[]>([])
|
||||||
@ -144,12 +145,11 @@ export default function Editor(props: EditorProps) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const newLineGroups = [...lineGroups, curLineGroup]
|
const newLineGroups = [...lineGroups, curLineGroup]
|
||||||
setLineGroups(newLineGroups)
|
|
||||||
setCurLineGroup([])
|
setCurLineGroup([])
|
||||||
setIsDraging(false)
|
setIsDraging(false)
|
||||||
setIsInpaintingLoading(true)
|
setIsInpaintingLoading(true)
|
||||||
|
|
||||||
drawAllLinesOnMask(newLineGroups)
|
drawAllLinesOnMask(newLineGroups)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const res = await inpaint(
|
const res = await inpaint(
|
||||||
file,
|
file,
|
||||||
@ -165,9 +165,16 @@ export default function Editor(props: EditorProps) {
|
|||||||
const newRenders = [...renders, newRender]
|
const newRenders = [...renders, newRender]
|
||||||
setRenders(newRenders)
|
setRenders(newRenders)
|
||||||
draw(newRender, [])
|
draw(newRender, [])
|
||||||
|
// Only append new LineGroup after inpainting success
|
||||||
|
setLineGroups(newLineGroups)
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
// eslint-disable-next-line
|
setToastState({
|
||||||
alert(e.message ? e.message : e.toString())
|
open: true,
|
||||||
|
desc: e.message ? e.message : e.toString(),
|
||||||
|
state: 'error',
|
||||||
|
duration: 2000,
|
||||||
|
})
|
||||||
|
drawOnCurrentRender([])
|
||||||
}
|
}
|
||||||
setIsInpaintingLoading(false)
|
setIsInpaintingLoading(false)
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,11 @@ export enum HDStrategy {
|
|||||||
CROP = 'Crop',
|
CROP = 'Crop',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export enum LDMSampler {
|
||||||
|
ddim = 'ddim',
|
||||||
|
plms = 'plms',
|
||||||
|
}
|
||||||
|
|
||||||
function HDSettingBlock() {
|
function HDSettingBlock() {
|
||||||
const [setting, setSettingState] = useRecoilState(settingState)
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import React, { ReactNode } from 'react'
|
|||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import { settingState } from '../../store/Atoms'
|
import { settingState } from '../../store/Atoms'
|
||||||
import Selector from '../shared/Selector'
|
import Selector from '../shared/Selector'
|
||||||
|
import { LDMSampler } from './HDSettingBlock'
|
||||||
import NumberInputSetting from './NumberInputSetting'
|
import NumberInputSetting from './NumberInputSetting'
|
||||||
import SettingBlock from './SettingBlock'
|
import SettingBlock from './SettingBlock'
|
||||||
|
|
||||||
@ -19,6 +20,12 @@ function ModelSettingBlock() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onLDMSamplerChange = (value: LDMSampler) => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, ldmSampler: value }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const renderModelDesc = (
|
const renderModelDesc = (
|
||||||
name: string,
|
name: string,
|
||||||
paperUrl: string,
|
paperUrl: string,
|
||||||
@ -65,6 +72,19 @@ function ModelSettingBlock() {
|
|||||||
})
|
})
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<SettingBlock
|
||||||
|
className="sub-setting-block"
|
||||||
|
title="Sampler"
|
||||||
|
input={
|
||||||
|
<Selector
|
||||||
|
width={80}
|
||||||
|
value={setting.ldmSampler as string}
|
||||||
|
options={Object.values(LDMSampler)}
|
||||||
|
onChange={val => onLDMSamplerChange(val as LDMSampler)}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ export default function ShortcutsModal() {
|
|||||||
/>
|
/>
|
||||||
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
|
<ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
|
||||||
<ShortCut content="Pan" keys={['Space & Drag']} />
|
<ShortCut content="Pan" keys={['Space & Drag']} />
|
||||||
<ShortCut content="View Original Image" keys={['Hold Tag']} />
|
<ShortCut content="View Original Image" keys={['Hold Tab']} />
|
||||||
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
|
<ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
|
||||||
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
|
<ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
|
||||||
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
|
<ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />
|
||||||
|
@ -2,8 +2,9 @@
|
|||||||
all: unset;
|
all: unset;
|
||||||
flex: 1 0 auto;
|
flex: 1 0 auto;
|
||||||
border-radius: 0.5rem;
|
border-radius: 0.5rem;
|
||||||
padding: 0.4rem 0.8rem;
|
padding: 0 0.8rem;
|
||||||
outline: 1px solid var(--border-color);
|
outline: 1px solid var(--border-color);
|
||||||
|
height: 36px;
|
||||||
|
|
||||||
&:focus-visible {
|
&:focus-visible {
|
||||||
outline: 1px solid var(--yellow-accent);
|
outline: 1px solid var(--yellow-accent);
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { atom } from 'recoil'
|
import { atom } from 'recoil'
|
||||||
import { HDStrategy } from '../components/Settings/HDSettingBlock'
|
import { HDStrategy, LDMSampler } from '../components/Settings/HDSettingBlock'
|
||||||
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
||||||
import { ToastState } from '../components/shared/Toast'
|
import { ToastState } from '../components/shared/Toast'
|
||||||
|
|
||||||
@ -43,6 +43,7 @@ export interface Settings {
|
|||||||
|
|
||||||
// For LDM
|
// For LDM
|
||||||
ldmSteps: number
|
ldmSteps: number
|
||||||
|
ldmSampler: LDMSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
export const settingStateDefault = {
|
export const settingStateDefault = {
|
||||||
@ -50,6 +51,7 @@ export const settingStateDefault = {
|
|||||||
runInpaintingManually: false,
|
runInpaintingManually: false,
|
||||||
model: AIModel.LAMA,
|
model: AIModel.LAMA,
|
||||||
ldmSteps: 50,
|
ldmSteps: 50,
|
||||||
|
ldmSampler: LDMSampler.plms,
|
||||||
hdStrategy: HDStrategy.RESIZE,
|
hdStrategy: HDStrategy.RESIZE,
|
||||||
hdStrategyResizeLimit: 2048,
|
hdStrategyResizeLimit: 2048,
|
||||||
hdStrategyCropTrigerSize: 2048,
|
hdStrategyCropTrigerSize: 2048,
|
||||||
|
193
lama_cleaner/model/ddim_sampler.py
Normal file
193
lama_cleaner/model/ddim_sampler.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear"):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(
|
||||||
|
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||||
|
):
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
|
ddim_discr_method=ddim_discretize,
|
||||||
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
|
# array([1])
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
||||||
|
assert (
|
||||||
|
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||||
|
), "alphas have to be defined for each timestep"
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer("betas", to_torch(self.model.betas))
|
||||||
|
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer(
|
||||||
|
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_one_minus_alphas_cumprod",
|
||||||
|
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"sqrt_recipm1_alphas_cumprod",
|
||||||
|
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||||
|
alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||||
|
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||||
|
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||||
|
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev)
|
||||||
|
/ (1 - self.alphas_cumprod)
|
||||||
|
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self, steps, conditioning, batch_size, shape):
|
||||||
|
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
|
||||||
|
# samples: 1,3,128,128
|
||||||
|
return self.ddim_sampling(
|
||||||
|
conditioning,
|
||||||
|
size,
|
||||||
|
quantize_denoised=False,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=0,
|
||||||
|
temperature=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_sampling(
|
||||||
|
self,
|
||||||
|
cond,
|
||||||
|
shape,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
||||||
|
timesteps = (
|
||||||
|
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
)
|
||||||
|
|
||||||
|
time_range = (
|
||||||
|
reversed(range(0, timesteps))
|
||||||
|
if ddim_use_original_steps
|
||||||
|
else np.flip(timesteps)
|
||||||
|
)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
outs = self.p_sample_ddim(
|
||||||
|
img,
|
||||||
|
cond,
|
||||||
|
ts,
|
||||||
|
index=index,
|
||||||
|
use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised,
|
||||||
|
temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
)
|
||||||
|
img, _ = outs
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_ddim(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
c,
|
||||||
|
t,
|
||||||
|
index,
|
||||||
|
repeat_noise=False,
|
||||||
|
use_original_steps=False,
|
||||||
|
quantize_denoised=False,
|
||||||
|
temperature=1.0,
|
||||||
|
noise_dropout=0.0,
|
||||||
|
):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = (
|
||||||
|
self.model.alphas_cumprod_prev
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_alphas_prev
|
||||||
|
)
|
||||||
|
sqrt_one_minus_alphas = (
|
||||||
|
self.model.sqrt_one_minus_alphas_cumprod
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sqrt_one_minus_alphas
|
||||||
|
)
|
||||||
|
sigmas = (
|
||||||
|
self.model.ddim_sigmas_for_original_num_steps
|
||||||
|
if use_original_steps
|
||||||
|
else self.ddim_sigmas
|
||||||
|
)
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full(
|
||||||
|
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised: # 没用
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.0: # 没用
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
@ -5,17 +5,15 @@ import torch
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
||||||
|
from lama_cleaner.model.plms_sampler import PLMSSampler
|
||||||
|
from lama_cleaner.schema import Config, LDMSampler
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tqdm import tqdm
|
|
||||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||||
from lama_cleaner.model.utils import (
|
from lama_cleaner.model.utils import (
|
||||||
make_beta_schedule,
|
make_beta_schedule,
|
||||||
make_ddim_timesteps,
|
|
||||||
make_ddim_sampling_parameters,
|
|
||||||
noise_like,
|
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,7 +92,7 @@ class DDPM(nn.Module):
|
|||||||
self.linear_start = linear_start
|
self.linear_start = linear_start
|
||||||
self.linear_end = linear_end
|
self.linear_end = linear_end
|
||||||
assert (
|
assert (
|
||||||
alphas_cumprod.shape[0] == self.num_timesteps
|
alphas_cumprod.shape[0] == self.num_timesteps
|
||||||
), "alphas have to be defined for each timestep"
|
), "alphas have to be defined for each timestep"
|
||||||
|
|
||||||
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
||||||
@ -120,7 +118,7 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||||
posterior_variance = (1 - self.v_posterior) * betas * (
|
posterior_variance = (1 - self.v_posterior) * betas * (
|
||||||
1.0 - alphas_cumprod_prev
|
1.0 - alphas_cumprod_prev
|
||||||
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
||||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||||
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
||||||
@ -142,16 +140,16 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
if self.parameterization == "eps":
|
if self.parameterization == "eps":
|
||||||
lvlb_weights = self.betas ** 2 / (
|
lvlb_weights = self.betas ** 2 / (
|
||||||
2
|
2
|
||||||
* self.posterior_variance
|
* self.posterior_variance
|
||||||
* to_torch(alphas)
|
* to_torch(alphas)
|
||||||
* (1 - self.alphas_cumprod)
|
* (1 - self.alphas_cumprod)
|
||||||
)
|
)
|
||||||
elif self.parameterization == "x0":
|
elif self.parameterization == "x0":
|
||||||
lvlb_weights = (
|
lvlb_weights = (
|
||||||
0.5
|
0.5
|
||||||
* np.sqrt(torch.Tensor(alphas_cumprod))
|
* np.sqrt(torch.Tensor(alphas_cumprod))
|
||||||
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("mu not supported")
|
raise NotImplementedError("mu not supported")
|
||||||
@ -221,192 +219,6 @@ class LatentDiffusion(DDPM):
|
|||||||
return x_recon
|
return x_recon
|
||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
|
||||||
def __init__(self, model, schedule="linear"):
|
|
||||||
super().__init__()
|
|
||||||
self.model = model
|
|
||||||
self.ddpm_num_timesteps = model.num_timesteps
|
|
||||||
self.schedule = schedule
|
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
|
||||||
setattr(self, name, attr)
|
|
||||||
|
|
||||||
def make_schedule(
|
|
||||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
|
||||||
):
|
|
||||||
self.ddim_timesteps = make_ddim_timesteps(
|
|
||||||
ddim_discr_method=ddim_discretize,
|
|
||||||
num_ddim_timesteps=ddim_num_steps,
|
|
||||||
# array([1])
|
|
||||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
|
||||||
assert (
|
|
||||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
|
||||||
), "alphas have to be defined for each timestep"
|
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
|
||||||
|
|
||||||
self.register_buffer("betas", to_torch(self.model.betas))
|
|
||||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
|
||||||
self.register_buffer(
|
|
||||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_one_minus_alphas_cumprod",
|
|
||||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sqrt_recipm1_alphas_cumprod",
|
|
||||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ddim sampling parameters
|
|
||||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
|
||||||
alphacums=alphas_cumprod.cpu(),
|
|
||||||
ddim_timesteps=self.ddim_timesteps,
|
|
||||||
eta=ddim_eta,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
|
||||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
|
||||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
|
||||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
|
||||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
|
||||||
(1 - self.alphas_cumprod_prev)
|
|
||||||
/ (1 - self.alphas_cumprod)
|
|
||||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample(self, steps, conditioning, batch_size, shape):
|
|
||||||
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
|
||||||
# sampling
|
|
||||||
C, H, W = shape
|
|
||||||
size = (batch_size, C, H, W)
|
|
||||||
|
|
||||||
# samples: 1,3,128,128
|
|
||||||
return self.ddim_sampling(
|
|
||||||
conditioning,
|
|
||||||
size,
|
|
||||||
quantize_denoised=False,
|
|
||||||
ddim_use_original_steps=False,
|
|
||||||
noise_dropout=0,
|
|
||||||
temperature=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def ddim_sampling(
|
|
||||||
self,
|
|
||||||
cond,
|
|
||||||
shape,
|
|
||||||
ddim_use_original_steps=False,
|
|
||||||
quantize_denoised=False,
|
|
||||||
temperature=1.0,
|
|
||||||
noise_dropout=0.0,
|
|
||||||
):
|
|
||||||
device = self.model.betas.device
|
|
||||||
b = shape[0]
|
|
||||||
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
|
||||||
timesteps = (
|
|
||||||
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
|
||||||
)
|
|
||||||
|
|
||||||
time_range = (
|
|
||||||
reversed(range(0, timesteps))
|
|
||||||
if ddim_use_original_steps
|
|
||||||
else np.flip(timesteps)
|
|
||||||
)
|
|
||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
|
||||||
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
|
||||||
|
|
||||||
for i, step in enumerate(iterator):
|
|
||||||
index = total_steps - i - 1
|
|
||||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
|
||||||
|
|
||||||
outs = self.p_sample_ddim(
|
|
||||||
img,
|
|
||||||
cond,
|
|
||||||
ts,
|
|
||||||
index=index,
|
|
||||||
use_original_steps=ddim_use_original_steps,
|
|
||||||
quantize_denoised=quantize_denoised,
|
|
||||||
temperature=temperature,
|
|
||||||
noise_dropout=noise_dropout,
|
|
||||||
)
|
|
||||||
img, _ = outs
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def p_sample_ddim(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
c,
|
|
||||||
t,
|
|
||||||
index,
|
|
||||||
repeat_noise=False,
|
|
||||||
use_original_steps=False,
|
|
||||||
quantize_denoised=False,
|
|
||||||
temperature=1.0,
|
|
||||||
noise_dropout=0.0,
|
|
||||||
):
|
|
||||||
b, *_, device = *x.shape, x.device
|
|
||||||
e_t = self.model.apply_model(x, t, c)
|
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
|
||||||
alphas_prev = (
|
|
||||||
self.model.alphas_cumprod_prev
|
|
||||||
if use_original_steps
|
|
||||||
else self.ddim_alphas_prev
|
|
||||||
)
|
|
||||||
sqrt_one_minus_alphas = (
|
|
||||||
self.model.sqrt_one_minus_alphas_cumprod
|
|
||||||
if use_original_steps
|
|
||||||
else self.ddim_sqrt_one_minus_alphas
|
|
||||||
)
|
|
||||||
sigmas = (
|
|
||||||
self.model.ddim_sigmas_for_original_num_steps
|
|
||||||
if use_original_steps
|
|
||||||
else self.ddim_sigmas
|
|
||||||
)
|
|
||||||
# select parameters corresponding to the currently considered timestep
|
|
||||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
|
||||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
|
||||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
|
||||||
sqrt_one_minus_at = torch.full(
|
|
||||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
# current prediction for x_0
|
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
|
||||||
if quantize_denoised: # 没用
|
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
|
||||||
# direction pointing to x_t
|
|
||||||
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
|
||||||
if noise_dropout > 0.0: # 没用
|
|
||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
|
||||||
return x_prev, pred_x0
|
|
||||||
|
|
||||||
|
|
||||||
def load_jit_model(url, device):
|
def load_jit_model(url, device):
|
||||||
model_path = download_model(url)
|
model_path = download_model(url)
|
||||||
logger.info(f"Load LDM model from: {model_path}")
|
logger.info(f"Load LDM model from: {model_path}")
|
||||||
@ -432,8 +244,7 @@ class LDM(InpaintModel):
|
|||||||
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
||||||
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
||||||
|
|
||||||
model = LatentDiffusion(self.diffusion_model, device)
|
self.model = LatentDiffusion(self.diffusion_model, device)
|
||||||
self.sampler = DDIMSampler(model)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_downloaded() -> bool:
|
def is_downloaded() -> bool:
|
||||||
@ -454,6 +265,13 @@ class LDM(InpaintModel):
|
|||||||
# image [1,3,512,512] float32
|
# image [1,3,512,512] float32
|
||||||
# mask: [1,1,512,512] float32
|
# mask: [1,1,512,512] float32
|
||||||
# masked_image: [1,3,512,512] float32
|
# masked_image: [1,3,512,512] float32
|
||||||
|
if config.ldm_sampler == LDMSampler.ddim:
|
||||||
|
sampler = DDIMSampler(self.model)
|
||||||
|
elif config.ldm_sampler == LDMSampler.plms:
|
||||||
|
sampler = PLMSSampler(self.model)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
steps = config.ldm_steps
|
steps = config.ldm_steps
|
||||||
image = norm_img(image)
|
image = norm_img(image)
|
||||||
mask = norm_img(mask)
|
mask = norm_img(mask)
|
||||||
@ -465,7 +283,6 @@ class LDM(InpaintModel):
|
|||||||
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||||
masked_image = (1 - mask) * image
|
masked_image = (1 - mask) * image
|
||||||
|
|
||||||
image = self._norm(image)
|
|
||||||
mask = self._norm(mask)
|
mask = self._norm(mask)
|
||||||
masked_image = self._norm(masked_image)
|
masked_image = self._norm(masked_image)
|
||||||
|
|
||||||
@ -476,7 +293,7 @@ class LDM(InpaintModel):
|
|||||||
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
||||||
|
|
||||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
shape = (c.shape[1] - 1,) + c.shape[2:]
|
||||||
samples_ddim = self.sampler.sample(
|
samples_ddim = sampler.sample(
|
||||||
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
225
lama_cleaner/model/plms_sampler.py
Normal file
225
lama_cleaner/model/plms_sampler.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class PLMSSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
if ddim_eta != 0:
|
||||||
|
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta, verbose=verbose)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||||
|
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
|
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self,
|
||||||
|
steps,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=False,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
|
samples = self.plms_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sampling(self, cond, shape,
|
||||||
|
x_T=None, ddim_use_original_steps=False,
|
||||||
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, ):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
if x_T is None:
|
||||||
|
img = torch.randn(shape, device=device)
|
||||||
|
else:
|
||||||
|
img = x_T
|
||||||
|
|
||||||
|
if timesteps is None:
|
||||||
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
elif timesteps is not None and not ddim_use_original_steps:
|
||||||
|
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||||
|
timesteps = self.ddim_timesteps[:subset_end]
|
||||||
|
|
||||||
|
time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||||
|
old_eps = []
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
old_eps=old_eps, t_next=ts_next)
|
||||||
|
img, pred_x0, e_t = outs
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
def get_model_output(x, t):
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
e_t = get_model_output(x, t)
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = get_model_output(x_prev, t_next)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
elif len(old_eps) >= 3:
|
||||||
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
return x_prev, pred_x0, e_t
|
@ -9,8 +9,14 @@ class HDStrategy(str, Enum):
|
|||||||
CROP = 'Crop'
|
CROP = 'Crop'
|
||||||
|
|
||||||
|
|
||||||
|
class LDMSampler(str, Enum):
|
||||||
|
ddim = 'ddim'
|
||||||
|
plms = 'plms'
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
ldm_steps: int
|
ldm_steps: int
|
||||||
|
ldm_sampler: str
|
||||||
hd_strategy: str
|
hd_strategy: str
|
||||||
hd_strategy_crop_margin: int
|
hd_strategy_crop_margin: int
|
||||||
hd_strategy_crop_trigger_size: int
|
hd_strategy_crop_trigger_size: int
|
||||||
|
@ -93,6 +93,7 @@ def process():
|
|||||||
|
|
||||||
config = Config(
|
config = Config(
|
||||||
ldm_steps=form["ldmSteps"],
|
ldm_steps=form["ldmSteps"],
|
||||||
|
ldm_sampler=form["ldmSampler"],
|
||||||
hd_strategy=form["hdStrategy"],
|
hd_strategy=form["hdStrategy"],
|
||||||
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
|
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
|
||||||
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
|
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
|
||||||
|
2
setup.py
2
setup.py
@ -21,7 +21,7 @@ def load_requirements():
|
|||||||
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name="lama-cleaner",
|
name="lama-cleaner",
|
||||||
version="0.12.0",
|
version="0.13.0",
|
||||||
author="PanicByte",
|
author="PanicByte",
|
||||||
author_email="cwq1913@gmail.com",
|
author_email="cwq1913@gmail.com",
|
||||||
description="Image inpainting tool powered by SOTA AI Model",
|
description="Image inpainting tool powered by SOTA AI Model",
|
||||||
|
Loading…
Reference in New Issue
Block a user