Merge pull request #38 from Sanster/0409_optimize

0409 optimize
This commit is contained in:
Qing 2022-04-18 22:32:18 +08:00 committed by GitHub
commit b43883a567
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
75 changed files with 2123 additions and 555 deletions

4
.gitignore vendored
View File

@ -3,3 +3,7 @@
examples/
.idea/
.vscode/
build
!lama_cleaner/app/build
dist/
lama_cleaner.egg-info/

View File

@ -18,18 +18,15 @@ https://user-images.githubusercontent.com/3998421/153323093-b664bb68-2928-480b-b
Available commands for `main.py`
| Name | Description | Default |
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------- |
| --model | lama or ldm. See details in **Model Comparison** | lama |
| --device | cuda or cpu | cuda |
| --ldm-steps | The larger the value, the better the result, but it will be more time-consuming | 50 |
| --crop-trigger-size | If image size large then crop-trigger-size, crop each area from original image to do inference. Mainly for performance and memory reasons on **very** large image. | 2042,2042 |
| --crop-margin | Margin around bounding box of painted stroke when crop mode triggered. | 256 |
| --gui | Launch lama-cleaner as a desktop application | |
| --gui_size | Set the window size for the application | 1200 900 |
| --input | Path to image you want to load by default | None |
| --port | Port for flask web server | 8080 |
| --debug | Enable debug mode for flask web server | |
| Name | Description | Default |
| ---------- | ------------------------------------------------ | -------- |
| --model | lama or ldm. See details in **Model Comparison** | lama |
| --device | cuda or cpu | cuda |
| --gui | Launch lama-cleaner as a desktop application | |
| --gui_size | Set the window size for the application | 1200 900 |
| --input | Path to image you want to load by default | None |
| --port | Port for flask web server | 8080 |
| --debug | Enable debug mode for flask web server | |
## Model Comparison

7
lama_cleaner/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from lama_cleaner.parse_args import parse_args
from lama_cleaner.server import main
def entry_point():
args = parse_args()
main(args)

View File

@ -17,6 +17,8 @@
"project": "./tsconfig.json"
},
"rules": {
"jsx-a11y/click-events-have-key-events": 0,
"react/jsx-props-no-spreading": 0,
"import/no-unresolved": 0,
"react/jsx-no-bind": "off",
"react/jsx-filename-extension": [

View File

@ -1,17 +1,17 @@
{
"files": {
"main.css": "/static/css/main.fd853425.chunk.css",
"main.js": "/static/js/main.c06ba56c.chunk.js",
"main.css": "/static/css/main.5c7abe60.chunk.css",
"main.js": "/static/js/main.0a74f667.chunk.js",
"runtime-main.js": "/static/js/runtime-main.5e86ac81.js",
"static/js/2.97604ba9.chunk.js": "/static/js/2.97604ba9.chunk.js",
"static/js/2.4cb726d6.chunk.js": "/static/js/2.4cb726d6.chunk.js",
"index.html": "/index.html",
"static/js/2.97604ba9.chunk.js.LICENSE.txt": "/static/js/2.97604ba9.chunk.js.LICENSE.txt",
"static/js/2.4cb726d6.chunk.js.LICENSE.txt": "/static/js/2.4cb726d6.chunk.js.LICENSE.txt",
"static/media/_index.scss": "/static/media/WorkSans-SemiBold.1e98db4e.ttf"
},
"entrypoints": [
"static/js/runtime-main.5e86ac81.js",
"static/js/2.97604ba9.chunk.js",
"static/css/main.fd853425.chunk.css",
"static/js/main.c06ba56c.chunk.js"
"static/js/2.4cb726d6.chunk.js",
"static/css/main.5c7abe60.chunk.css",
"static/js/main.0a74f667.chunk.js"
]
}

View File

@ -1 +1 @@
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by LaMa</title><link href="/static/css/main.fd853425.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>"localhost"===location.hostname&&(self.FIREBASE_APPCHECK_DEBUG_TOKEN=!0)</script><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.97604ba9.chunk.js"></script><script src="/static/js/main.c06ba56c.chunk.js"></script></body></html>
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by LaMa</title><link href="/static/css/main.5c7abe60.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>"localhost"===location.hostname&&(self.FIREBASE_APPCHECK_DEBUG_TOKEN=!0)</script><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.4cb726d6.chunk.js"></script><script src="/static/js/main.0a74f667.chunk.js"></script></body></html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -5,6 +5,8 @@
"proxy": "http://localhost:8080",
"dependencies": {
"@heroicons/react": "^1.0.4",
"@radix-ui/react-switch": "^0.1.5",
"@radix-ui/react-toast": "^0.1.1",
"@testing-library/jest-dom": "^5.14.1",
"@testing-library/react": "^12.1.2",
"@testing-library/user-event": "^13.5.0",

View File

@ -1,10 +1,12 @@
import { Settings } from '../store/Atoms'
import { dataURItoBlob } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
export default async function inpaint(
imageFile: File,
maskBase64: string,
settings: Settings,
sizeLimit?: string
) {
// 1080, 2000, Original
@ -13,13 +15,22 @@ export default async function inpaint(
const mask = dataURItoBlob(maskBase64)
fd.append('mask', mask)
fd.append('ldmSteps', settings.ldmSteps.toString())
fd.append('hdStrategy', settings.hdStrategy)
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
fd.append(
'hdStrategyCropTrigerSize',
settings.hdStrategyCropTrigerSize.toString()
)
fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {
fd.append('sizeLimit', sizeLimit)
}
const res = await fetch(API_ENDPOINT, {
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
method: 'POST',
body: fd,
}).then(async r => {
@ -28,3 +39,24 @@ export default async function inpaint(
return URL.createObjectURL(res)
}
export function switchModel(name: string) {
const fd = new FormData()
fd.append('name', name)
return fetch(`${API_ENDPOINT}/model`, {
method: 'POST',
body: fd,
})
}
export function currentModel() {
return fetch(`${API_ENDPOINT}/model`, {
method: 'GET',
})
}
export function modelDownloaded(name: string) {
return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, {
method: 'GET',
})
}

View File

@ -15,12 +15,14 @@ import {
TransformComponent,
TransformWrapper,
} from 'react-zoom-pan-pinch'
import { useRecoilValue } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint from '../../adapters/inpainting'
import Button from '../shared/Button'
import Slider from './Slider'
import SizeSelector from './SizeSelector'
import { downloadImage, loadImage, useImage } from '../../utils'
import { settingState } from '../../store/Atoms'
const TOOLBAR_SIZE = 200
const BRUSH_COLOR = '#ffcc00bb'
@ -57,6 +59,7 @@ function drawLines(
export default function Editor(props: EditorProps) {
const { file } = props
const settings = useRecoilValue(settingState)
const [brushSize, setBrushSize] = useState(40)
const [original, isOriginalLoaded] = useImage(file)
const [renders, setRenders] = useState<HTMLImageElement[]>([])
@ -73,10 +76,11 @@ export default function Editor(props: EditorProps) {
const [showOriginal, setShowOriginal] = useState(false)
const [isInpaintingLoading, setIsInpaintingLoading] = useState(false)
const [scale, setScale] = useState<number>(1)
const [minScale, setMinScale] = useState<number>()
const [minScale, setMinScale] = useState<number>(1.0)
const [sizeLimit, setSizeLimit] = useState<number>(1080)
const windowSize = useWindowSize()
const viewportRef = useRef<ReactZoomPanPinchRef | undefined | null>()
const [centered, setCentered] = useState(false)
const [isDraging, setIsDraging] = useState(false)
const [isMultiStrokeKeyPressed, setIsMultiStrokeKeyPressed] = useState(false)
@ -124,6 +128,7 @@ export default function Editor(props: EditorProps) {
const res = await inpaint(
file,
maskCanvas.toDataURL(),
settings,
sizeLimit.toString()
)
if (!res) {
@ -156,6 +161,7 @@ export default function Editor(props: EditorProps) {
renders,
sizeLimit,
historyLineCount,
settings,
])
const hadDrawSomething = () => {
@ -214,31 +220,38 @@ export default function Editor(props: EditorProps) {
// Draw once the original image is loaded
useEffect(() => {
if (!original) {
if (!isOriginalLoaded) {
return
}
if (isOriginalLoaded) {
const rW = windowSize.width / original.naturalWidth
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
if (rW < 1 || rH < 1) {
const s = Math.min(rW, rH)
setMinScale(s)
setScale(s)
} else {
setMinScale(1)
}
const rW = windowSize.width / original.naturalWidth
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
const imageSizeLimit = Math.max(original.width, original.height)
setSizeLimit(imageSizeLimit)
if (context?.canvas) {
context.canvas.width = original.naturalWidth
context.canvas.height = original.naturalHeight
}
draw()
let s = 1.0
if (rW < 1 || rH < 1) {
s = Math.min(rW, rH)
}
}, [context?.canvas, draw, original, isOriginalLoaded, windowSize])
setMinScale(s)
setScale(s)
const imageSizeLimit = Math.max(original.width, original.height)
setSizeLimit(imageSizeLimit)
if (context?.canvas) {
context.canvas.width = original.naturalWidth
context.canvas.height = original.naturalHeight
}
viewportRef.current?.centerView(s, 1)
setCentered(true)
draw()
}, [
context?.canvas,
draw,
viewportRef,
original,
isOriginalLoaded,
windowSize,
])
// Zoom reset
const resetZoom = useCallback(() => {
@ -522,10 +535,6 @@ export default function Editor(props: EditorProps) {
}
}
if (!original || !scale || !minScale) {
return <></>
}
return (
<div
className="editor-container"
@ -543,7 +552,7 @@ export default function Editor(props: EditorProps) {
wheel={{ step: 0.05 }}
centerZoomedOut
alignmentAnimation={{ disabled: true }}
centerOnInit
// centerOnInit
limitToBounds={false}
doubleClick={{ disabled: true }}
initialScale={minScale}
@ -554,6 +563,9 @@ export default function Editor(props: EditorProps) {
>
<TransformComponent
contentClass={isInpaintingLoading ? 'editor-canvas-loading' : ''}
contentStyle={{
visibility: centered ? 'visible' : 'hidden',
}}
>
<div className="editor-canvas-container">
<canvas

View File

@ -23,3 +23,11 @@ header {
gap: 12px;
justify-self: end;
}
.header-icons {
display: flex;
justify-content: center;
align-items: center;
gap: 6px;
justify-self: end;
}

View File

@ -6,6 +6,7 @@ import Button from '../shared/Button'
import Shortcuts from '../Shortcuts/Shortcuts'
import useResolution from '../../hooks/useResolution'
import { ThemeChanger } from './ThemeChanger'
import SettingIcon from '../Settings/SettingIcon'
const Header = () => {
const [file, setFile] = useRecoilState(fileState)
@ -26,7 +27,11 @@ const Header = () => {
</Button>
</div>
<div className="header-icons-wrapper">
<div style={{ visibility: file ? 'visible' : 'hidden' }}>
<div
className="header-icons"
style={{ visibility: file ? 'visible' : 'hidden' }}
>
<SettingIcon />
<Shortcuts />
</div>
<ThemeChanger />

View File

@ -0,0 +1,7 @@
.hd-setting-block {
.inline-tip {
display: inline;
cursor: pointer;
color: var(--text-color);
}
}

View File

@ -0,0 +1,132 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Selector from '../shared/Selector'
import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'
export enum HDStrategy {
ORIGINAL = 'Original',
RESIZE = 'Resize',
CROP = 'Crop',
}
function HDSettingBlock() {
const [setting, setSettingState] = useRecoilState(settingState)
const onStrategyChange = (value: HDStrategy) => {
setSettingState(old => {
return { ...old, hdStrategy: value }
})
}
const onResizeLimitChange = (value: string) => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, hdStrategyResizeLimit: val }
})
}
const onCropTriggerSizeChange = (value: string) => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, hdStrategyCropTrigerSize: val }
})
}
const onCropMarginChange = (value: string) => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, hdStrategyCropMargin: val }
})
}
const renderOriginalOptionDesc = () => {
return (
<div>
Use the original resolution of the picture, suitable for picture size
below 2K. Try{' '}
<div
tabIndex={0}
role="button"
className="inline-tip"
onClick={() => onStrategyChange(HDStrategy.RESIZE)}
>
Resize Strategy
</div>{' '}
if you do not get good results on high resolution images.
</div>
)
}
const renderResizeOptionDesc = () => {
return (
<div>
<div>
Resize the longer side of the image to a specific size(keep ratio),
then do inpainting on the resized image.
</div>
<NumberInputSetting
title="Size limit"
value={`${setting.hdStrategyResizeLimit}`}
suffix="pixel"
onValue={onResizeLimitChange}
/>
</div>
)
}
const renderCropOptionDesc = () => {
return (
<div>
<div>
Crop masking area from the original image to do inpainting, and paste
the result back. Mainly for performance and memory reasons on high
resolution image.
</div>
<NumberInputSetting
title="Trigger size"
value={`${setting.hdStrategyCropTrigerSize}`}
suffix="pixel"
onValue={onCropTriggerSizeChange}
/>
<NumberInputSetting
title="Crop margin"
value={`${setting.hdStrategyCropMargin}`}
suffix="pixel"
onValue={onCropMarginChange}
/>
</div>
)
}
const renderHDStrategyOptionDesc = (): ReactNode => {
switch (setting.hdStrategy) {
case HDStrategy.ORIGINAL:
return renderOriginalOptionDesc()
case HDStrategy.CROP:
return renderCropOptionDesc()
case HDStrategy.RESIZE:
return renderResizeOptionDesc()
default:
return renderOriginalOptionDesc()
}
}
return (
<SettingBlock
className="hd-setting-block"
title="High Resolution Strategy"
input={
<Selector
value={setting.hdStrategy as string}
options={Object.values(HDStrategy)}
onChange={val => onStrategyChange(val as HDStrategy)}
/>
}
optionDesc={renderHDStrategyOptionDesc()}
/>
)
}
export default HDSettingBlock

View File

@ -0,0 +1,4 @@
.model-desc-link {
color: var(--text-color-gray);
text-decoration: none;
}

View File

@ -0,0 +1,103 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Selector from '../shared/Selector'
import NumberInputSetting from './NumberInputSetting'
import SettingBlock from './SettingBlock'
export enum AIModel {
LAMA = 'lama',
LDM = 'ldm',
}
function ModelSettingBlock() {
const [setting, setSettingState] = useRecoilState(settingState)
const onModelChange = (value: AIModel) => {
setSettingState(old => {
return { ...old, model: value }
})
}
const renderModelDesc = (
name: string,
paperUrl: string,
githubUrl: string
) => {
return (
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
<a
className="model-desc-link"
href={paperUrl}
target="_blank"
rel="noreferrer noopener"
>
{name}
</a>
<a
className="model-desc-link"
href={githubUrl}
target="_blank"
rel="noreferrer noopener"
>
{githubUrl}
</a>
</div>
)
}
const renderLDMModelDesc = () => {
return (
<div>
{renderModelDesc(
'High-Resolution Image Synthesis with Latent Diffusion Models',
'https://arxiv.org/abs/2112.10752',
'https://github.com/CompVis/latent-diffusion'
)}
<NumberInputSetting
title="Steps"
value={`${setting.ldmSteps}`}
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, ldmSteps: val }
})
}}
/>
</div>
)
}
const renderOptionDesc = (): ReactNode => {
switch (setting.model) {
case AIModel.LAMA:
return renderModelDesc(
'Resolution-robust Large Mask Inpainting with Fourier Convolutions',
'https://arxiv.org/abs/2109.07161',
'https://github.com/saic-mdal/lama'
)
case AIModel.LDM:
return renderLDMModelDesc()
default:
return <></>
}
}
return (
<SettingBlock
className="model-setting-block"
title="Inpainting Model"
input={
<Selector
value={setting.model as string}
options={Object.values(AIModel)}
onChange={val => onModelChange(val as AIModel)}
/>
}
optionDesc={renderOptionDesc()}
/>
)
}
export default ModelSettingBlock

View File

@ -0,0 +1,40 @@
import React from 'react'
import NumberInput from '../shared/NumberInput'
import SettingBlock from './SettingBlock'
interface NumberInputSettingProps {
title: string
value: string
suffix?: string
onValue: (val: string) => void
}
function NumberInputSetting(props: NumberInputSettingProps) {
const { title, value, suffix, onValue } = props
return (
<SettingBlock
className="sub-setting-block"
title={title}
input={
<div
style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
gap: '8px',
}}
>
<NumberInput
style={{ width: '80px' }}
value={`${value}`}
onValue={onValue}
/>
{suffix && <span>{suffix}</span>}
</div>
}
/>
)
}
export default NumberInputSetting

View File

@ -0,0 +1,31 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import { Switch, SwitchThumb } from '../shared/Switch'
import SettingBlock from './SettingBlock'
function SavePathSettingBlock() {
const [setting, setSettingState] = useRecoilState(settingState)
const onCheckChange = (checked: boolean) => {
setSettingState(old => {
return { ...old, saveImageBesideOrigin: checked }
})
}
return (
<SettingBlock
title="Download image beside origin image"
input={
<Switch
checked={setting.saveImageBesideOrigin}
onCheckedChange={onCheckChange}
>
<SwitchThumb />
</Switch>
}
/>
)
}
export default SavePathSettingBlock

View File

@ -0,0 +1,36 @@
.setting-block {
display: flex;
flex-direction: column;
.option-desc {
color: var(--text-color-gray);
margin-top: 12px;
border: 1px solid var(--border-color);
border-radius: 0.3rem;
padding: 1rem;
.sub-setting-block {
margin-top: 8px;
color: var(--text-color);
}
}
}
.setting-block-content {
display: flex;
justify-content: space-between;
align-items: center;
gap: 12rem;
}
.setting-block-content-title {
display: flex;
flex-direction: column;
justify-content: space-between;
}
.setting-block-desc {
font-size: 1rem;
margin-top: 8px;
color: var(--text-color-gray);
}

View File

@ -0,0 +1,27 @@
import React, { ReactNode } from 'react'
interface SettingBlockProps {
title: string
desc?: string
input: ReactNode
optionDesc?: ReactNode
className?: string
}
function SettingBlock(props: SettingBlockProps) {
const { title, desc, input, optionDesc, className } = props
return (
<div className={`setting-block ${className}`}>
<div className="setting-block-content">
<div className="setting-block-content-title">
<span>{title}</span>
{desc && <span className="setting-block-desc">{desc}</span>}
</div>
{input}
</div>
{optionDesc && <div className="option-desc">{optionDesc}</div>}
</div>
)
}
export default SettingBlock

View File

@ -0,0 +1,46 @@
import React from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Button from '../shared/Button'
const SettingIcon = () => {
const [setting, setSettingState] = useRecoilState(settingState)
const onClick = () => {
setSettingState({ ...setting, show: !setting.show })
}
return (
<div>
<Button
onClick={onClick}
style={{ border: 0 }}
icon={
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
role="img"
width="28"
height="28"
viewBox="0 0 24 24"
stroke="currentColor"
strokeWidth="2"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
/>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
/>
</svg>
}
/>
</div>
)
}
export default SettingIcon

View File

@ -0,0 +1,20 @@
@use '../../styles/Mixins/' as *;
@import './SettingBlock.scss';
@import './HDSettingBlock.scss';
@import './ModelSettingBlock.scss';
.modal-setting {
grid-area: main-content;
background-color: var(--modal-bg);
color: var(--modal-text-color);
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
width: 700px;
@include mobile {
display: grid;
width: 100%;
height: auto;
margin-top: -11rem;
animation: slideDown 0.2s ease-out;
}
}

View File

@ -0,0 +1,37 @@
import React from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import Modal from '../shared/Modal'
import HDSettingBlock from './HDSettingBlock'
import ModelSettingBlock from './ModelSettingBlock'
interface SettingModalProps {
onClose: () => void
}
export default function SettingModal(props: SettingModalProps) {
const { onClose } = props
const [setting, setSettingState] = useRecoilState(settingState)
const handleOnClose = () => {
setSettingState(old => {
return { ...old, show: false }
})
onClose()
}
return (
<Modal
onClose={handleOnClose}
title="Settings"
className="modal-setting"
show={setting.show}
>
{/* It's not possible because this poses a security risk */}
{/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
{/* <SavePathSettingBlock /> */}
<ModelSettingBlock />
<HDSettingBlock />
</Modal>
)
}

View File

@ -1,5 +1,5 @@
import React, { ReactNode } from 'react'
import { useSetRecoilState } from 'recoil'
import { useRecoilState } from 'recoil'
import { shortcutsState } from '../../store/Atoms'
import Modal from '../shared/Modal'
@ -19,13 +19,8 @@ function ShortCut(props: Shortcut) {
)
}
interface ShortcutsModalProps {
show: boolean
}
export default function ShortcutsModal(props: ShortcutsModalProps) {
const { show } = props
const setShortcutState = useSetRecoilState(shortcutsState)
export default function ShortcutsModal() {
const [shortcutsShow, setShortcutState] = useRecoilState(shortcutsState)
const shortcutStateHandler = () => {
setShortcutState(false)
@ -36,7 +31,7 @@ export default function ShortcutsModal(props: ShortcutsModalProps) {
onClose={shortcutStateHandler}
title="Hotkeys"
className="modal-shortcuts"
show={show}
show={shortcutsShow}
>
<div className="shortcut-options">
<ShortCut content="Enable multi-stroke mask drawing">

View File

@ -1,19 +1,99 @@
import React from 'react'
import { useRecoilValue } from 'recoil'
import React, { useEffect } from 'react'
import { useRecoilState } from 'recoil'
import Editor from './Editor/Editor'
import { shortcutsState } from '../store/Atoms'
import ShortcutsModal from './Shortcuts/ShortcutsModal'
import SettingModal from './Settings/SettingsModal'
import Toast from './shared/Toast'
import { Settings, settingState, toastState } from '../store/Atoms'
import {
currentModel,
modelDownloaded,
switchModel,
} from '../adapters/inpainting'
import { AIModel } from './Settings/ModelSettingBlock'
interface WorkspaceProps {
file: File
}
const Workspace = ({ file }: WorkspaceProps) => {
const shortcutVisbility = useRecoilValue(shortcutsState)
const [settings, setSettingState] = useRecoilState(settingState)
const [toastVal, setToastState] = useRecoilState(toastState)
const onSettingClose = async () => {
const curModel = await currentModel().then(res => res.text())
if (curModel === settings.model) {
return
}
const downloaded = await modelDownloaded(settings.model).then(res =>
res.text()
)
const { model } = settings
let loadingMessage = `Switching to ${model} model`
let loadingDuration = 3000
if (downloaded === 'False') {
loadingMessage = `Downloading ${model} model, this may take a while`
loadingDuration = 9999999999
}
setToastState({
open: true,
desc: loadingMessage,
state: 'loading',
duration: loadingDuration,
})
switchModel(model)
.then(res => {
if (res.ok) {
setToastState({
open: true,
desc: `Switch to ${model} model success`,
state: 'success',
duration: 3000,
})
} else {
throw new Error('Server error')
}
})
.catch(() => {
setToastState({
open: true,
desc: `Switch to ${model} model failed`,
state: 'error',
duration: 3000,
})
setSettingState(old => {
return { ...old, model: curModel as AIModel }
})
})
}
useEffect(() => {
currentModel()
.then(res => res.text())
.then(model => {
setSettingState(old => {
return { ...old, model: model as AIModel }
})
})
}, [])
return (
<>
<Editor file={file} />
<ShortcutsModal show={shortcutVisbility} />
<SettingModal onClose={onSettingClose} />
<ShortcutsModal />
<Toast
{...toastVal}
onOpenChange={(open: boolean) => {
setToastState(old => {
return { ...old, open }
})
}}
/>
</>
)
}

View File

@ -1,6 +1,6 @@
import { XIcon } from '@heroicons/react/outline'
import React, { ReactNode, useRef } from 'react'
import { useClickAway, useKey } from 'react-use'
import { useClickAway, useKey, useKeyPress, useKeyPressEvent } from 'react-use'
import Button from './Button'
export interface ModalProps {
@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
const ref = useRef(null)
useClickAway(ref, () => {
onClose?.()
if (show) {
onClose?.()
}
})
useKey('Escape', onClose, {
event: 'keydown',
useKeyPressEvent('Escape', e => {
if (show) {
onClose?.()
}
})
return (
@ -30,7 +34,7 @@ export default function Modal(props: ModalProps) {
>
<div ref={ref} className={`modal ${className}`}>
<div className="modal-header">
<h3>{title}</h3>
<h2>{title}</h2>
<Button icon={<XIcon />} onClick={onClose} />
</div>
{children}

View File

@ -0,0 +1,12 @@
.number-input {
all: unset;
flex: 1 0 auto;
border-radius: 0.5rem;
padding: 0.2rem 0.8rem;
line-height: 1;
outline: 1px solid var(--border-color);
&:focus-visible {
outline: 1px solid var(--yellow-accent);
}
}

View File

@ -0,0 +1,31 @@
import React, { FormEvent, InputHTMLAttributes } from 'react'
interface NumberInputProps extends InputHTMLAttributes<HTMLInputElement> {
value: string
onValue?: (val: string) => void
}
const NumberInput = React.forwardRef<HTMLInputElement, NumberInputProps>(
(props: NumberInputProps, forwardedRef) => {
const { value, onValue, ...itemProps } = props
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
const target = evt.target as HTMLInputElement
const val = target.value.replace(/\D/g, '')
onValue?.(val)
}
return (
<input
value={value}
onInput={handleOnInput}
className="number-input"
{...itemProps}
ref={forwardedRef}
type="text"
/>
)
}
)
export default NumberInput

View File

@ -0,0 +1,74 @@
@use '../../styles/Mixins' as *;
.selector {
position: relative;
display: flex;
flex-direction: column;
align-items: center;
justify-content: space-between;
}
.selector-main {
@include accented-display(var(white));
width: 100%;
user-select: none;
display: flex;
justify-content: space-between;
align-items: center;
cursor: pointer;
outline: none;
gap: 8px;
padding: 0.2rem 0.8rem;
border: 1px solid var(--border-color);
color: var(--options-text-color);
svg {
width: 1rem;
height: 1rem;
margin-top: 0.25rem;
}
}
.selector-options {
@include accented-display(var(--btn-primary-bg));
width: 100%;
padding: 0;
display: grid;
justify-self: center;
position: absolute;
cursor: pointer;
top: 3rem;
color: var(--options-text-color);
background-color: var(--page-bg);
border: 1px solid var(--border-color);
border-radius: 0.6rem;
@include mobile {
bottom: 11.5rem;
}
.selector-option {
display: flex;
align-items: center;
user-select: none;
padding: 0.5rem 0.8rem;
&:first-of-type {
border-top-right-radius: 0.5rem;
border-top-left-radius: 0.5rem;
}
&:last-of-type {
border-bottom-left-radius: 0.5rem;
border-bottom-right-radius: 0.5rem;
}
&:hover {
background-color: var(--yellow-accent);
color: var(--btn-text-hover-color);
}
}
}

View File

@ -0,0 +1,86 @@
import React, { MutableRefObject, useCallback, useRef, useState } from 'react'
import { useClickAway, useKeyPressEvent } from 'react-use'
import { ChevronDownIcon, ChevronUpIcon } from '@heroicons/react/outline'
type SelectorChevronDirection = 'up' | 'down'
type SelectorProps = {
minWidth?: number
chevronDirection?: SelectorChevronDirection
value: string
options: string[]
onChange: (value: string) => void
}
const selectorDefaultProps = {
minWidth: 128,
chevronDirection: 'down',
}
function Selector(props: SelectorProps) {
const { minWidth, chevronDirection, value, options, onChange } = props
const [showOptions, setShowOptions] = useState<boolean>(false)
const selectorRef = useRef<HTMLDivElement | null>(null)
const showOptionsHandler = () => {
// console.log(selectorRef.current?.focus)
// selectorRef?.current?.focus()
setShowOptions(currentShowOptionsState => !currentShowOptionsState)
}
useClickAway(selectorRef, () => {
setShowOptions(false)
})
// TODO: how to prevent Modal close?
// useKeyPressEvent('Escape', (e: KeyboardEvent) => {
// if (showOptions === true) {
// console.log(`selector ${e}`)
// e.preventDefault()
// e.stopPropagation()
// setShowOptions(false)
// }
// })
const onOptionClick = (e: any, newIndex: number) => {
const currentRes = e.target.textContent.split('x')
onChange(currentRes[0])
setShowOptions(false)
}
return (
<div className="selector" ref={selectorRef} style={{ minWidth }}>
<div
className="selector-main"
role="button"
onClick={showOptionsHandler}
aria-hidden="true"
>
<p>{value}</p>
<div className="selector-icon">
{chevronDirection === 'up' ? <ChevronUpIcon /> : <ChevronDownIcon />}
</div>
</div>
{showOptions && (
<div className="selector-options">
{options.map((val, _index) => (
<div
className="selector-option"
role="button"
tabIndex={0}
key={val}
onClick={e => onOptionClick(e, _index)}
aria-hidden="true"
>
{val}
</div>
))}
</div>
)}
</div>
)
}
Selector.defaultProps = selectorDefaultProps
export default Selector

View File

@ -0,0 +1,36 @@
.switch-root {
all: 'unset';
width: 42px;
height: 25px;
background-color: var(--switch-root-background-color);
border-radius: 9999px;
border: none;
position: relative;
transition: background-color 100ms;
-webkit-tap-highlight-color: rgba(0, 0, 0, 0);
&:focus-visible {
outline: none;
}
}
.switch-root[data-state='checked'] {
background-color: var(--yellow-accent);
}
.switch-thumb {
display: block;
width: 17px;
height: 17px;
background-color: var(--switch-thumb-color);
border-radius: 9999px;
transition: transform 100ms;
transform: translateX(4px);
will-change: transform;
}
.switch-thumb[data-state='checked'] {
transform: translateX(21px);
background-color: var(--switch-thumb-checked-color);
outline: 1px solid rgb(100, 100, 120, 0.5);
}

View File

@ -0,0 +1,34 @@
import React from 'react'
import * as SwitchPrimitive from '@radix-ui/react-switch'
const Switch = React.forwardRef<
React.ElementRef<typeof SwitchPrimitive.Root>,
React.ComponentProps<typeof SwitchPrimitive.Root>
>((props, forwardedRef) => {
const { className, ...itemProps } = props
return (
<SwitchPrimitive.Root
{...itemProps}
ref={forwardedRef}
className={`switch-root ${className}`}
/>
)
})
const SwitchThumb = React.forwardRef<
React.ElementRef<typeof SwitchPrimitive.Thumb>,
React.ComponentProps<typeof SwitchPrimitive.Thumb>
>((props, forwardedRef) => {
const { className, ...itemProps } = props
return (
<SwitchPrimitive.Thumb
{...itemProps}
ref={forwardedRef}
className={`switch-thumb ${className}`}
/>
)
})
export { Switch, SwitchThumb }

View File

@ -0,0 +1,83 @@
.toast-viewpoint {
position: fixed;
top: 48px;
right: 0;
display: flex;
flex-direction: row;
padding: 25px;
gap: 10px;
max-width: 100vw;
margin: 0;
z-index: 999999;
&:focus-visible {
outline: none;
}
}
.toast-root {
border: 1px solid var(--border-color-light);
background-color: var(--page-bg);
border-radius: 0.6rem;
padding: 15px;
display: flex;
align-items: center;
gap: 12px;
&[data-state='open'] {
animation: slideIn 150ms cubic-bezier(0.16, 1, 0.3, 1);
}
&[data-state='close'] {
animation: opacityReveal 100ms ease-in forwards;
}
&[data-state='cancel'] {
transform: translateX(0);
animation: transform 100ms ease-out;
}
&.error {
border: 1px solid var(--error-color);
}
&.success {
border: 1px solid var(--success-color);
}
}
.error-icon {
height: 24px;
width: 24px;
color: var(--error-color);
}
.success-icon {
height: 24px;
width: 24px;
color: var(--success-color);
}
.loading-icon {
display: flex;
align-items: center;
animation-name: spin;
animation-duration: 1500ms;
animation-iteration-count: infinite;
transform-origin: center center;
animation-timing-function: linear;
}
.toast-icon {
display: flex;
align-items: center;
}
.toast-desc {
display: flex;
align-items: center;
margin: 0;
color: var(--text-color);
min-width: 240px;
}

View File

@ -0,0 +1,81 @@
import * as React from 'react'
import * as ToastPrimitive from '@radix-ui/react-toast'
import { ToastProps } from '@radix-ui/react-toast'
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/outline'
const LoadingIcon = () => {
return (
<span className="loading-icon">
<svg
xmlns="http://www.w3.org/2000/svg"
width="20"
height="20"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
>
<line x1="12" y1="2" x2="12" y2="6" />
<line x1="12" y1="18" x2="12" y2="22" />
<line x1="4.93" y1="4.93" x2="7.76" y2="7.76" />
<line x1="16.24" y1="16.24" x2="19.07" y2="19.07" />
<line x1="2" y1="12" x2="6" y2="12" />
<line x1="18" y1="12" x2="22" y2="12" />
<line x1="4.93" y1="19.07" x2="7.76" y2="16.24" />
<line x1="16.24" y1="7.76" x2="19.07" y2="4.93" />
</svg>
</span>
)
}
export type ToastState = 'default' | 'error' | 'loading' | 'success'
interface MyToastProps extends ToastProps {
desc: string
state?: ToastState
}
const Toast = React.forwardRef<
React.ElementRef<typeof ToastPrimitive.Root>,
MyToastProps
>((props, forwardedRef) => {
const { state, desc, ...itemProps } = props
const getIcon = () => {
switch (state) {
case 'error':
return <ExclamationCircleIcon className="error-icon" />
case 'success':
return <CheckIcon className="success-icon" />
case 'loading':
return <LoadingIcon />
default:
return <></>
}
}
return (
<ToastPrimitive.Provider>
<ToastPrimitive.Root
{...itemProps}
ref={forwardedRef}
className={`toast-root ${state}`}
>
<div className="toast-icon">{getIcon()}</div>
<ToastPrimitive.Description className="toast-desc">
{desc}
</ToastPrimitive.Description>
</ToastPrimitive.Root>
<ToastPrimitive.Viewport className="toast-viewpoint" />
</ToastPrimitive.Provider>
)
})
Toast.defaultProps = {
desc: '',
state: 'loading',
}
export default Toast

View File

@ -8,14 +8,21 @@ export default function useInputImage() {
headers.append('pragma', 'no-cache')
headers.append('cache-control', 'no-cache')
fetch('/inputimage', { headers })
.then(res => res.blob())
.then(data => {
if (data && data.type.startsWith('image')) {
const userInput = new File([data], 'inputImage')
setInputImage(userInput)
}
})
fetch('/inputimage', { headers }).then(async res => {
const filename = res.headers
.get('content-disposition')
?.split('filename=')[1]
.split(';')[0]
const data = await res.blob()
if (data && data.type.startsWith('image')) {
const userInput = new File(
[data],
filename !== undefined ? filename : 'inputImage'
)
setInputImage(userInput)
}
})
}, [setInputImage])
useEffect(() => {

View File

@ -1,11 +1,82 @@
import { atom } from 'recoil'
import { HDStrategy } from '../components/Settings/HDSettingBlock'
import { AIModel } from '../components/Settings/ModelSettingBlock'
import { ToastState } from '../components/shared/Toast'
export const fileState = atom<File | undefined>({
key: 'fileState',
default: undefined,
})
interface ToastAtomState {
open: boolean
desc: string
state: ToastState
duration: number
}
export const toastState = atom<ToastAtomState>({
key: 'toastState',
default: {
open: false,
desc: '',
state: 'default',
duration: 3000,
},
})
export const shortcutsState = atom<boolean>({
key: 'shortcutsState',
default: false,
})
export interface Settings {
show: boolean
saveImageBesideOrigin: boolean
model: AIModel
// For LaMa
hdStrategy: HDStrategy
hdStrategyResizeLimit: number
hdStrategyCropTrigerSize: number
hdStrategyCropMargin: number
// For LDM
ldmSteps: number
}
export const settingStateDefault = {
show: false,
saveImageBesideOrigin: false,
model: AIModel.LAMA,
ldmSteps: 50,
hdStrategy: HDStrategy.RESIZE,
hdStrategyResizeLimit: 2048,
hdStrategyCropTrigerSize: 2048,
hdStrategyCropMargin: 128,
}
const localStorageEffect =
(key: string) =>
({ setSelf, onSet }: any) => {
const savedValue = localStorage.getItem(key)
if (savedValue != null) {
const storageSettings = JSON.parse(savedValue)
storageSettings.show = false
setSelf(storageSettings)
}
onSet((newValue: Settings, _: string, isReset: boolean) =>
isReset
? localStorage.removeItem(key)
: localStorage.setItem(key, JSON.stringify(newValue))
)
}
// Each atom can reference an array of these atom effect functions which are called in priority order when the atom is initialized
// https://recoiljs.org/docs/guides/atom-effects/#local-storage-persistence
export const settingState = atom<Settings>({
key: 'settingsState',
default: settingStateDefault,
effects: [localStorageEffect('settingsState')],
})

View File

@ -37,3 +37,21 @@
transform: translateY(0);
}
}
@keyframes slideIn {
0% {
transform: translateX(calc(100% + 25px));
}
100% {
transform: translateX(0);
}
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

View File

@ -7,6 +7,10 @@
--yellow-accent: #ffcc00;
--link-color: rgb(0, 0, 0);
--border-color: rgb(100, 100, 120);
--border-color-light: rgba(100, 100, 120, 0.5);
--error-color: rgb(239, 68, 68);
--success-color: rgb(16, 185, 129);
// Editor
--editor-toolkit-bg: rgba(255, 255, 255, 0.5);
@ -17,13 +21,22 @@
--modal-bg: var(--page-bg);
--modal-text-color: rgb(0, 0, 0);
--modal-hotkey-border-color: rgb(0, 0, 0);
--model-mask-bg: rgba(209,213,219,0.4);
--model-mask-bg: rgba(209, 213, 219, 0.4);
// Text
--text-color: #040404;
--text-color-gray: rgb(107, 111, 118);
// Shared
--btn-primary-bg: rgb(210, 210, 220);
--btn-text-color: black;
--btn-text-hover-color: black;
--btn-text-color: var(--text-color);
--btn-text-hover-color: #040404;
--btn-border-color: rgb(100, 100, 120);
--btn-primary-hover-bg: var(--yellow-accent);
--animation-pulsing-bg: rgb(255, 255, 255, 0.5);
// switch
--switch-root-background-color: rgb(223, 225, 228);
--switch-thumb-color: var(--page-bg);
--switch-thumb-checked-color: var(--page-bg);
}

View File

@ -3,10 +3,11 @@
// Theme
--page-bg: #040404;
--page-text-color: #F9F9F9;
--page-text-color: #f9f9f9;
--yellow-accent: #ffcc00;
--link-color: var(--yellow-accent);
--border-color: rgb(100, 100, 120);
--border-color-light: rgba(102, 102, 102);
// Editor
--editor-toolkit-bg: rgba(0, 0, 0, 0.5);
@ -17,14 +18,23 @@
--modal-bg: var(--page-bg);
--modal-text-color: var(--page-text-color);
// --modal-hotkey-bg: rgb(60, 60, 90);
--modal-hotkey-border-color: var(--page-text-color);;
--modal-hotkey-border-color: var(--page-text-color);
--model-mask-bg: rgba(76, 76, 87, 0.4);
// Text
--text-color: white;
--text-color-gray: rgb(138, 143, 152);
// Shared
--btn-primary-bg: rgb(140, 140, 180);
--btn-text-color: white;
--btn-text-color: var(--text-color);
--btn-text-hover-color: var(--page-bg);
--btn-border-color: var(--yellow-accent);
--btn-primary-hover-bg: var(--yellow-accent);
--animation-pulsing-bg: rgb(240, 240, 255);
// switch
--switch-root-background-color: rgb(60, 63, 68);
--switch-thumb-color: rgb(31, 32, 35);
--switch-thumb-checked-color: white;
}

View File

@ -11,11 +11,16 @@
@use '../components/Header/Header';
@use '../components/Header/ThemeChanger';
@use '../components/Shortcuts/Shortcuts';
@use '../components/Settings/Settings.scss';
// Shared
@use '../components/FileSelect/FileSelect';
@use '../components/shared/Button';
@use '../components/shared/Modal';
@use '../components/shared/Selector';
@use '../components/shared/Switch';
@use '../components/shared/NumberInput';
@use '../components/shared/Toast';
// Main CSS
*,

View File

@ -1164,6 +1164,13 @@
dependencies:
regenerator-runtime "^0.13.4"
"@babel/runtime@^7.13.10":
version "7.17.9"
resolved "https://registry.npmmirror.com/@babel/runtime/-/runtime-7.17.9.tgz#d19fbf802d01a8cb6cf053a64e472d42c434ba72"
integrity sha512-lSiBBvodq29uShpWGNbgFdKYNiFDo5/HIYsaCEY9ff4sb10x9jizo2+pRrSyF4jKZCXqgzuqBOQKbUm90gQwJg==
dependencies:
regenerator-runtime "^0.13.4"
"@babel/template@^7.10.4", "@babel/template@^7.15.4", "@babel/template@^7.3.3":
version "7.15.4"
resolved "https://registry.npmjs.org/@babel/template/-/template-7.15.4.tgz"
@ -1537,6 +1544,186 @@
schema-utils "^2.6.5"
source-map "^0.7.3"
"@radix-ui/primitive@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/primitive/-/primitive-0.1.0.tgz#6206b97d379994f0d1929809db035733b337e543"
integrity sha512-tqxZKybwN5Fa3VzZry4G6mXAAb9aAqKmPtnVbZpL0vsBwvOHTBwsjHVPXylocYLwEtBY9SCe665bYnNB515uoA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-compose-refs/-/react-compose-refs-0.1.0.tgz#cff6e780a0f73778b976acff2c2a5b6551caab95"
integrity sha512-eyclbh+b77k+69Dk72q3694OHrn9B3QsoIRx7ywX341U9RK1ThgQjMFZoPtmZNQTksXHLNEiefR8hGVeFyInGg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-context@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-context/-/react-context-0.1.1.tgz#06996829ea124d9a1bc1dbe3e51f33588fab0875"
integrity sha512-PkyVX1JsLBioeu0jB9WvRpDBBLtLZohVDT3BB5CTSJqActma8S8030P57mWZb4baZifMvN7KKWPAA40UmWKkQg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-dismissable-layer@0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-0.1.5.tgz#9379032351e79028d472733a5cc8ba4a0ea43314"
integrity sha512-J+fYWijkX4M4QKwf9dtu1oC0U6e6CEl8WhBp3Ad23yz2Hia0XCo6Pk/mp5CAFy4QBtQedTSkhW05AdtSOEoajQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "0.1.0"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-use-body-pointer-events" "0.1.1"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-escape-keydown" "0.1.0"
"@radix-ui/react-id@0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8"
integrity sha512-IPc4H/63bes0IZ1GJJozSEkSWcDyhNGtKFWUpJ+XtaLyQ1X3x7Mf6fWwWhDcpqlYEP+5WtAvfqcyEsyjP+ZhBQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-label@0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-label/-/react-label-0.1.5.tgz#12cd965bfc983e0148121d4c99fb8e27a917c45c"
integrity sha512-Au9+n4/DhvjR0IHhvZ1LPdx/OW+3CGDie30ZyCkbSHIuLp4/CV4oPPGBwJ1vY99Jog3zyQhsGww9MXj8O9Aj/A==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-context" "0.1.1"
"@radix-ui/react-id" "0.1.5"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-portal@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2"
integrity sha512-MO0wRy2eYRTZ/CyOri9NANCAtAtq89DEtg90gicaTlkCfdqCLEBsLb+/q66BZQTr3xX/Vq01nnVfc/TkCqoqvw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-presence@0.1.2":
version "0.1.2"
resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3"
integrity sha512-3BRlFZraooIUfRlyN+b/Xs5hq1lanOOo/+3h6Pwu2GMFjkGKKa4Rd51fcqGqnVlbr3jYg+WLuGyAV4KlgqwrQw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-primitive@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1"
integrity sha512-6gSl2IidySupIMJFjYnDIkIWRyQdbu/AHK7rbICPani+LW4b0XdxBXc46og/iZvuwW8pjCS8I2SadIerv84xYA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-slot" "0.1.2"
"@radix-ui/react-slot@0.1.2":
version "0.1.2"
resolved "https://registry.npmmirror.com/@radix-ui/react-slot/-/react-slot-0.1.2.tgz#e6f7ad9caa8ce81cc8d532c854c56f9b8b6307c8"
integrity sha512-ADkqfL+agEzEguU3yS26jfB50hRrwf7U4VTwAOZEmi/g+ITcBWe12yM46ueS/UCIMI9Py+gFUaAdxgxafFvY2Q==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-switch@^0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-switch/-/react-switch-0.1.5.tgz#071ffa19a17a47fdc5c5e6f371bd5901c9fef2f4"
integrity sha512-ITtslJPK+Yi34iNf7K9LtsPaLD76oRIVzn0E8JpEO5HW8gpRBGb2NNI9mxKtEB30TVqIcdjdL10AmuIfOMwjtg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "0.1.0"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-context" "0.1.1"
"@radix-ui/react-label" "0.1.5"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-use-controllable-state" "0.1.0"
"@radix-ui/react-use-previous" "0.1.1"
"@radix-ui/react-use-size" "0.1.1"
"@radix-ui/react-toast@^0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-toast/-/react-toast-0.1.1.tgz#d544e796b307e56f1298e40f356f468680958e93"
integrity sha512-9JWC4mPP78OE6muDrpaPf/71dIeozppdcnik1IvsjTxZpDnt9PbTtQj94DdWjlCphbv3S5faD3KL0GOpqKBpTQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "0.1.0"
"@radix-ui/react-compose-refs" "0.1.0"
"@radix-ui/react-context" "0.1.1"
"@radix-ui/react-dismissable-layer" "0.1.5"
"@radix-ui/react-portal" "0.1.4"
"@radix-ui/react-presence" "0.1.2"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-controllable-state" "0.1.0"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-visually-hidden" "0.1.4"
"@radix-ui/react-use-body-pointer-events@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-body-pointer-events/-/react-use-body-pointer-events-0.1.1.tgz#63e7fd81ca7ffd30841deb584cd2b7f460df2597"
integrity sha512-R8leV2AWmJokTmERM8cMXFHWSiv/fzOLhG/JLmRBhLTAzOj37EQizssq4oW0Z29VcZy2tODMi9Pk/htxwb+xpA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "0.1.0"
"@radix-ui/react-use-callback-ref@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-0.1.0.tgz#934b6e123330f5b3a6b116460e6662cbc663493f"
integrity sha512-Va041McOFFl+aV+sejvl0BS2aeHx86ND9X/rVFmEFQKTXCp6xgUK0NGUAGcgBlIjnJSbMYPGEk1xKSSlVcN2Aw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-controllable-state@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-0.1.0.tgz#4fced164acfc69a4e34fb9d193afdab973a55de1"
integrity sha512-zv7CX/PgsRl46a52Tl45TwqwVJdmqnlQEQhaYMz/yBOD2sx2gCkCFSoF/z9mpnYWmS6DTLNTg5lIps3fV6EnXg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-escape-keydown@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874"
integrity sha512-tDLZbTGFmvXaazUXXv8kYbiCcbAE8yKgng9s95d8fCO+Eundv0Jngbn/hKPhDDs4jj9ChwRX5cDDnlaN+ugYYQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "0.1.0"
"@radix-ui/react-use-layout-effect@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223"
integrity sha512-+wdeS51Y+E1q1Wmd+1xSSbesZkpVj4jsg0BojCbopWvgq5iBvixw5vgemscdh58ep98BwUbsFYnrywFhV9yrVg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-previous@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-previous/-/react-use-previous-0.1.1.tgz#0226017f72267200f6e832a7103760e96a6db5d0"
integrity sha512-O/ZgrDBr11dR8rhO59ED8s5zIXBRFi8MiS+CmFGfi7MJYdLbfqVOmQU90Ghf87aifEgWe6380LA69KBneaShAg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-size@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-size/-/react-use-size-0.1.1.tgz#f6b75272a5d41c3089ca78c8a2e48e5f204ef90f"
integrity sha512-pTgWM5qKBu6C7kfKxrKPoBI2zZYZmp2cSXzpUiGM3qEBQlMLtYhaY2JXdXUCxz+XmD1YEjc8oRwvyfsD4AG4WA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-visually-hidden@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03"
integrity sha512-K/q6AEEzqeeEq/T0NPChvBqnwlp8Tl4NnQdrI/y8IOY7BRR+Ug0PEsVk6g48HJ7cA1//COugdxXXVVK/m0X1mA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "0.1.4"
"@rollup/plugin-node-resolve@^7.1.1":
version "7.1.3"
resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz"

View File

@ -9,7 +9,7 @@ import torch
from torch.hub import download_url_to_file, get_dir
def download_model(url):
def get_cache_path_by_url(url):
parts = urlparse(url)
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
@ -17,6 +17,11 @@ def download_model(url):
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
return cached_file
def download_model(url):
cached_file = get_cache_path_by_url(url)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
@ -31,7 +36,11 @@ def ceil_modulo(x, mod):
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
data = cv2.imencode(f".{ext}", image_numpy)[1]
data = cv2.imencode(f".{ext}", image_numpy,
[
int(cv2.IMWRITE_JPEG_QUALITY), 100,
int(cv2.IMWRITE_PNG_COMPRESSION), 0
])[1]
image_bytes = data.tobytes()
return image_bytes
@ -74,13 +83,24 @@ def resize_max_size(
return np_img
def pad_img_to_modulo(img, mod):
channels, height, width = img.shape
def pad_img_to_modulo(img: np.ndarray, mod: int):
"""
Args:
img: [H, W, C]
mod:
Returns:
"""
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
height, width = img.shape[:2]
out_height = ceil_modulo(height, mod)
out_width = ceil_modulo(width, mod)
return np.pad(
img,
((0, 0), (0, out_height - height), (0, out_width - width)),
((0, out_height - height), (0, out_width - width), (0, 0)),
mode="symmetric",
)
@ -88,15 +108,13 @@ def pad_img_to_modulo(img, mod):
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
"""
Args:
mask: (1, h, w) 0~1
mask: (h, w, 1) 0~255
Returns:
"""
height, width = mask.shape[1:]
_, thresh = cv2.threshold(
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
)
height, width = mask.shape[:2]
_, thresh = cv2.threshold(mask, 127, 255, 0)
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = []

View File

@ -1,121 +0,0 @@
import os
from typing import List
import cv2
import torch
import numpy as np
from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
class LaMa:
def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
"""
Args:
crop_trigger_size: h, w
crop_margin:
device:
"""
self.crop_trigger_size = crop_trigger_size
self.crop_margin = crop_margin
self.device = device
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else:
model_path = download_model(LAMA_MODEL_URL)
print(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
@torch.no_grad()
def __call__(self, image, mask):
"""
image: [C, H, W] RGB
mask: [1, H, W]
return: BGR IMAGE
"""
area = image.shape[1] * image.shape[2]
if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
return self._run(image, mask)
print("Trigger crop image")
boxes = boxes_from_mask(mask)
crop_result = []
for box in boxes:
crop_image, crop_box = self._run_box(image, mask, box)
crop_result.append((crop_image, crop_box))
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
image[y1:y2, x1:x2, :] = crop_image
return image
def _run_box(self, image, mask, box):
"""
Args:
image: [C, H, W] RGB
mask: [1, H, W]
box: [left,top,right,bottom]
Returns:
BGR IMAGE
"""
box_h = box[3] - box[1]
box_w = box[2] - box[0]
cx = (box[0] + box[2]) // 2
cy = (box[1] + box[3]) // 2
img_h, img_w = image.shape[1:]
w = box_w + self.crop_margin * 2
h = box_h + self.crop_margin * 2
l = max(cx - w // 2, 0)
t = max(cy - h // 2, 0)
r = min(cx + w // 2, img_w)
b = min(cy + h // 2, img_h)
crop_img = image[:, t:b, l:r]
crop_mask = mask[:, t:b, l:r]
print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
return self._run(crop_img, crop_mask), [l, t, r, b]
def _run(self, image, mask):
"""
image: [C, H, W] RGB
mask: [1, H, W]
return: BGR IMAGE
"""
device = self.device
origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=8)
mask = pad_img_to_modulo(mask, mod=8)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
inpainted_image = self.model(image, mask)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = cur_res[0:origin_height, 0:origin_width, :]
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
return cur_res

View File

127
lama_cleaner/model/base.py Normal file
View File

@ -0,0 +1,127 @@
import abc
import cv2
import torch
from loguru import logger
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
from lama_cleaner.schema import Config, HDStrategy
class InpaintModel:
pad_mod = 8
def __init__(self, device):
"""
Args:
device:
"""
self.device = device
self.init_model(device)
@abc.abstractmethod
def init_model(self, device):
...
@staticmethod
@abc.abstractmethod
def is_downloaded() -> bool:
...
@abc.abstractmethod
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
return: BGR IMAGE
"""
...
def _pad_forward(self, image, mask, config: Config):
origin_height, origin_width = image.shape[:2]
padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
result = self.forward(padd_image, padd_mask, config)
result = result[0:origin_height, 0:origin_width, :]
original_pixel_indices = mask != 255
result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
return result
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
image: [H, W, C] RGB, not normalized
mask: [H, W]
return: BGR IMAGE
"""
inpaint_result = None
logger.info(f"hd_strategy: {config.hd_strategy}")
if config.hd_strategy == HDStrategy.CROP:
if max(image.shape) > config.hd_strategy_crop_trigger_size:
logger.info(f"Run crop strategy")
boxes = boxes_from_mask(mask)
crop_result = []
for box in boxes:
crop_image, crop_box = self._run_box(image, mask, box, config)
crop_result.append((crop_image, crop_box))
inpaint_result = image[:, :, ::-1]
for crop_image, crop_box in crop_result:
x1, y1, x2, y2 = crop_box
inpaint_result[y1:y2, x1:x2, :] = crop_image
elif config.hd_strategy == HDStrategy.RESIZE:
if max(image.shape) > config.hd_strategy_resize_limit:
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
# only paste masked area result
inpaint_result = cv2.resize(inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC)
original_pixel_indices = mask != 255
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
if inpaint_result is None:
inpaint_result = self._pad_forward(image, mask, config)
return inpaint_result
def _run_box(self, image, mask, box, config: Config):
"""
Args:
image: [H, W, C] RGB
mask: [H, W, 1]
box: [left,top,right,bottom]
Returns:
BGR IMAGE
"""
box_h = box[3] - box[1]
box_w = box[2] - box[0]
cx = (box[0] + box[2]) // 2
cy = (box[1] + box[3]) // 2
img_h, img_w = image.shape[:2]
w = box_w + config.hd_strategy_crop_margin * 2
h = box_h + config.hd_strategy_crop_margin * 2
l = max(cx - w // 2, 0)
t = max(cy - h // 2, 0)
r = min(cx + w // 2, img_w)
b = min(cy + h // 2, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]

View File

@ -0,0 +1,68 @@
import os
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
LAMA_MODEL_URL = os.environ.get(
"LAMA_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
class LaMa(InpaintModel):
pad_mod = 8
def __init__(self, device):
"""
Args:
device:
"""
super().__init__(device)
self.device = device
def init_model(self, device):
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else:
model_path = download_model(LAMA_MODEL_URL)
logger.info(f"Load LaMa model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()
self.model = model
self.model_path = model_path
@staticmethod
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
return: BGR IMAGE
"""
image = norm_img(image)
mask = norm_img(mask)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
inpainted_image = self.model(image, mask)
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
return cur_res

View File

@ -2,13 +2,16 @@ import os
import numpy as np
import torch
from loguru import logger
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
import cv2
from lama_cleaner.helper import pad_img_to_modulo, download_model
from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
timestep_embedding
LDM_ENCODE_MODEL_URL = os.environ.get(
@ -217,7 +220,7 @@ class DDIMSampler(object):
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
@ -263,102 +266,54 @@ class DDIMSampler(object):
def load_jit_model(url, device):
model_path = download_model(url)
logger.info(f"Load LDM model from: {model_path}")
model = torch.jit.load(model_path).to(device)
model.eval()
return model
class LDM:
def __init__(self, device, steps=50):
class LDM(InpaintModel):
pad_mod = 32
def __init__(self, device):
super().__init__(device)
self.device = device
def init_model(self, device):
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
model = LatentDiffusion(self.diffusion_model, device)
self.sampler = DDIMSampler(model)
self.steps = steps
def _norm(self, tensor):
return tensor * 2.0 - 1.0
@staticmethod
def is_downloaded() -> bool:
model_paths = [
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
]
return all([os.path.exists(it) for it in model_paths])
@torch.no_grad()
def __call__(self, image, mask):
def forward(self, image, mask, config: Config):
"""
image: [C, H, W] RGB
mask: [1, H, W]
image: [H, W, C] RGB
mask: [H, W, 1]
return: BGR IMAGE
"""
# image [1,3,512,512] float32
# mask: [1,1,512,512] float32
# masked_image: [1,3,512,512] float32
origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=32)
mask = pad_img_to_modulo(mask, mod=32)
padded_height, padded_width = image.shape[1:]
steps = config.ldm_steps
image = norm_img(image)
mask = norm_img(mask)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# crop 512 x 512
if padded_width <= 512 or padded_height <= 512:
np_img = self._forward(image, mask, self.device)
else:
print("Try to zoom in")
# zoom in
# x,y,w,h
# box = self.box_from_bitmap(mask)
box = self.find_main_content(mask)
if box is None:
print("No bbox found")
np_img = self._forward(image, mask, self.device)
else:
print(f"box: {box}")
box_x, box_y, box_w, box_h = box
cx = box_x + box_w // 2
cy = box_y + box_h // 2
# w = max(512, box_w)
# h = max(512, box_h)
w = box_w + 512
h = box_h + 512
left = max(cx - w // 2, 0)
top = max(cy - h // 2, 0)
right = min(cx + w // 2, origin_width)
bottom = min(cy + h // 2, origin_height)
x = left
y = top
w = right - left
h = bottom - top
crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
print(f"Apply zoom in size width x height: {crop_img.shape}")
crop_img_height, crop_img_width = crop_img.shape[1:]
crop_img = pad_img_to_modulo(crop_img, mod=32)
crop_mask = pad_img_to_modulo(crop_mask, mod=32)
# RGB
np_img = self._forward(crop_img, crop_mask, self.device)
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
np_img = image
# BGR to RGB
# np_img = image[:, :, ::-1]
np_img = np_img[0:origin_height, 0:origin_width, :]
np_img = np_img[:, :, ::-1]
return np_img
def _forward(self, image, mask, device):
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
masked_image = (1 - mask) * image
image = self._norm(image)
@ -371,47 +326,20 @@ class LDM:
c = torch.cat((c, cc), dim=1) # 1,4,128,128
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim = self.sampler.sample(steps=self.steps,
samples_ddim = self.sampler.sample(steps=steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape)
x_samples_ddim = self.cond_stage_model_decode(samples_ddim) # samples_ddim: 1, 3, 128, 128 float32
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
inpainted = (1 - mask) * image + mask * predicted_image
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
np_img = inpainted.astype(np.uint8)
return np_img
# inpainted = (1 - mask) * image + mask * predicted_image
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
return inpainted_image
def find_main_content(self, bitmap: np.ndarray):
th2 = bitmap[0].astype(np.uint8)
row_sum = th2.sum(1)
col_sum = th2.sum(0)
xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
return left, top, right - left, bottom - top
def box_from_bitmap(self, bitmap):
"""
bitmap: single map with shape (NUM_CLASSES, H, W),
whose values are binarized as {0, 1}
"""
contours, _ = cv2.findContours(
(bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
)
contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
num_contours = len(contours)
print(f"contours size: {num_contours}")
if num_contours != 1:
return None
# x,y,w,h
return cv2.boundingRect(contours[0])
def _norm(self, tensor):
return tensor * 2.0 - 1.0

View File

@ -0,0 +1,42 @@
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.schema import Config
class ModelManager:
LAMA = 'lama'
LDM = 'ldm'
def __init__(self, name: str, device):
self.name = name
self.device = device
self.model = self.init_model(name, device)
def init_model(self, name: str, device):
if name == self.LAMA:
model = LaMa(device)
elif name == self.LDM:
model = LDM(device)
else:
raise NotImplementedError(f"Not supported model: {name}")
return model
def is_downloaded(self, name: str) -> bool:
if name == self.LAMA:
return LaMa.is_downloaded()
elif name == self.LDM:
return LDM.is_downloaded()
else:
raise NotImplementedError(f"Not supported model: {name}")
def __call__(self, image, mask, config: Config):
return self.model(image, mask, config)
def switch(self, new_name: str):
if new_name == self.name:
return
try:
self.model = self.init_model(new_name, self.device)
self.name = new_name
except NotImplementedError as e:
raise e

View File

@ -0,0 +1,32 @@
import os
import imghdr
import argparse
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(
"--gui-size",
default=[1600, 1000],
nargs=2,
type=int,
help="Set window size for GUI",
)
parser.add_argument(
"--input", type=str, help="Path to image you want to load by default"
)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
if args.input is not None:
if not os.path.exists(args.input):
parser.error(f"invalid --input: {args.input} not exists")
if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file")
return args

17
lama_cleaner/schema.py Normal file
View File

@ -0,0 +1,17 @@
from enum import Enum
from pydantic import BaseModel
class HDStrategy(str, Enum):
ORIGINAL = 'Original'
RESIZE = 'Resize'
CROP = 'Crop'
class Config(BaseModel):
ldm_steps: int
hd_strategy: str
hd_strategy_crop_margin: int
hd_strategy_crop_trigger_size: int
hd_strategy_resize_limit: int

189
lama_cleaner/server.py Normal file
View File

@ -0,0 +1,189 @@
#!/usr/bin/env python3
import io
import logging
import multiprocessing
import os
import time
import imghdr
from pathlib import Path
from typing import Union
import cv2
import torch
import numpy as np
from loguru import logger
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
try:
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
except:
pass
from flask import Flask, request, send_file, cli
# Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
cli.show_server_banner = lambda *_: None
from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
)
NUM_THREADS = str(multiprocessing.cpu_count())
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
class NoFlaskwebgui(logging.Filter):
def filter(self, record):
return 'GET //flaskwebgui-keep-server-alive' not in record.getMessage()
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition"])
model: ModelManager = None
device = None
input_image_path: str = None
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes)
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
form = request.form
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
config = Config(
ldm_steps=form['ldmSteps'],
hd_strategy=form['hdStrategy'],
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
)
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
logger.info(f"Resized image shape: {image.shape}")
mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
start = time.time()
res_np_img = model(image, mask, config)
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
)
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes)
return send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
mimetype=f"image/{ext}",
)
@app.route("/model")
def current_model():
return model.name, 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
return str(model.is_downloaded(name)), 200
@app.route("/model", methods=["POST"])
def switch_model():
new_name = request.form.get("name")
if new_name == model.name:
return "Same model", 200
try:
model.switch(new_name)
except NotImplementedError:
return f"{new_name} not implemented", 403
return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@app.route("/inputimage")
def set_input_photo():
if input_image_path:
with open(input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
input_image_path,
as_attachment=True,
download_name=Path(input_image_path).name,
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
return "No Input Image"
def main(args):
global model
global device
global input_image_path
device = torch.device(args.device)
input_image_path = args.input
model = ModelManager(name=args.model, device=device)
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI
ui = FlaskUI(app, width=app_width, height=app_height, host=args.host, port=args.port)
ui.run()
else:
app.run(host=args.host, port=args.port, debug=args.debug)

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

BIN
lama_cleaner/tests/mask.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.7 KiB

View File

@ -1,15 +0,0 @@
import cv2
import numpy as np
from lama_cleaner.helper import boxes_from_mask
def test_boxes_from_mask():
mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
mask = mask[:, :, np.newaxis]
mask = (mask / 255).transpose(2, 0, 1)
boxes = boxes_from_mask(mask)
print(boxes)
test_boxes_from_mask()

View File

@ -0,0 +1,55 @@
import os
from pathlib import Path
import cv2
import numpy as np
import pytest
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy
current_dir = Path(__file__).parent.absolute().resolve()
def get_data():
img = cv2.imread(str(current_dir / 'image.png'))
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
return img, mask
def get_config(strategy):
return Config(
ldm_steps=1,
hd_strategy=strategy,
hd_strategy_crop_margin=32,
hd_strategy_crop_trigger_size=200,
hd_strategy_resize_limit=200,
)
def assert_equal(model, config, gt_name):
img, mask = get_data()
res = model(img, mask, config)
# cv2.imwrite(gt_name, res,
# [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])
"""
Note that JPEG is lossy compression, so even if it is the highest quality 100,
when the saved image is reloaded, a difference occurs with the original pixel value.
If you want to save the original image as it is, save it as PNG or BMP.
"""
gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
assert np.array_equal(res, gt)
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_lama(strategy):
model = ModelManager(name='lama', device='cpu')
assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
def test_ldm(strategy):
model = ModelManager(name='ldm', device='cpu')
assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')

210
main.py
View File

@ -1,210 +1,4 @@
#!/usr/bin/env python3
import argparse
import io
import multiprocessing
import os
import time
import imghdr
from typing import Union
import cv2
import torch
import numpy as np
from lama_cleaner.lama import LaMa
from lama_cleaner.ldm import LDM
from flaskwebgui import FlaskUI
try:
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
except:
pass
from flask import Flask, request, send_file
from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
norm_img,
numpy_to_bytes,
resize_max_size,
)
NUM_THREADS = str(multiprocessing.cpu_count())
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app)
model = None
device = None
input_image_path: str = None
def get_image_ext(img_bytes):
w = imghdr.what("", img_bytes)
if w is None:
w = "jpeg"
return w
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes)
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
print(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
print(f"Resized image shape: {image.shape}")
image = norm_img(image)
mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask)
start = time.time()
res_np_img = model(image, mask)
print(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
)
res_np_img = np.concatenate(
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
)
ext = get_image_ext(origin_image_bytes)
return send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
mimetype=f"image/{ext}",
)
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@app.route("/inputimage")
def set_input_photo():
if input_image_path:
with open(input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
io.BytesIO(image_in_bytes),
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
return "No Input Image"
def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input", type=str, help="Path to image you want to load by default"
)
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument(
"--crop-trigger-size",
default=[2042, 2042],
nargs=2,
type=int,
help="If image size large then crop-trigger-size, "
"crop each area from original image to do inference."
"Mainly for performance and memory reasons"
"Only for lama",
)
parser.add_argument(
"--crop-margin",
type=int,
default=256,
help="Margin around bounding box of painted stroke when crop mode triggered",
)
parser.add_argument(
"--ldm-steps",
default=50,
type=int,
help="Steps for DDIM sampling process."
"The larger the value, the better the result, but it will be more time-consuming",
)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(
"--gui-size",
default=[1600, 1000],
nargs=2,
type=int,
help="Set window size for GUI",
)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
if args.input is not None:
if not os.path.exists(args.input):
parser.error(f"invalid --input: {args.input} not exists")
if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file")
return args
def main():
global model
global device
global input_image_path
args = get_args_parser()
device = torch.device(args.device)
input_image_path = args.input
if args.model == "lama":
model = LaMa(
crop_trigger_size=args.crop_trigger_size,
crop_margin=args.crop_margin,
device=device,
)
elif args.model == "ldm":
model = LDM(device, steps=args.ldm_steps)
else:
raise NotImplementedError(f"Not supported model: {args.model}")
if args.gui:
app_width, app_height = args.gui_size
ui = FlaskUI(app, width=app_width, height=app_height)
ui.run()
else:
app.run(host="127.0.0.1", port=args.port, debug=args.debug)
from lama_cleaner import entry_point
if __name__ == "__main__":
main()
entry_point()

10
publish.sh Normal file
View File

@ -0,0 +1,10 @@
#!/usr/bin/env bash
set -e
pushd ./lama_cleaner/app
yarn run build
popd
rm -r -f dist
python3 setup.py sdist bdist_wheel
twine upload dist/*

2
requirements-dev.txt Normal file
View File

@ -0,0 +1,2 @@
wheel
twine

View File

@ -1,6 +1,9 @@
torch>=1.8.2
opencv-python
flask_cors
flask
flask>=2.1.1
flaskwebgui
tqdm
pydantic
loguru
pytest

View File

@ -1 +1,41 @@
# TODO: make this a python package
import setuptools
from pathlib import Path
web_files = Path("lama_cleaner/app/build/").glob("**/*")
web_files = [str(it).replace("lama_cleaner/", "") for it in web_files]
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
def load_requirements():
requirements_file_name = "requirements.txt"
requires = []
with open(requirements_file_name) as f:
for line in f:
if line:
requires.append(line.strip())
return requires
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
setuptools.setup(
name="lama-cleaner",
version="0.9.1",
author="PanicByte",
author_email="cwq1913@gmail.com",
description="Image inpainting tool powered by SOTA AI Model",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/Sanster/lama-cleaner",
packages=setuptools.find_packages("./"),
package_data={"lama_cleaner": web_files},
install_requires=load_requirements(),
python_requires=">=3.6",
entry_points={"console_scripts": ["lama-cleaner=lama_cleaner:entry_point"]},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
)