4
.gitignore
vendored
@ -3,3 +3,7 @@
|
||||
examples/
|
||||
.idea/
|
||||
.vscode/
|
||||
build
|
||||
!lama_cleaner/app/build
|
||||
dist/
|
||||
lama_cleaner.egg-info/
|
||||
|
21
README.md
@ -18,18 +18,15 @@ https://user-images.githubusercontent.com/3998421/153323093-b664bb68-2928-480b-b
|
||||
|
||||
Available commands for `main.py`
|
||||
|
||||
| Name | Description | Default |
|
||||
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------- |
|
||||
| --model | lama or ldm. See details in **Model Comparison** | lama |
|
||||
| --device | cuda or cpu | cuda |
|
||||
| --ldm-steps | The larger the value, the better the result, but it will be more time-consuming | 50 |
|
||||
| --crop-trigger-size | If image size large then crop-trigger-size, crop each area from original image to do inference. Mainly for performance and memory reasons on **very** large image. | 2042,2042 |
|
||||
| --crop-margin | Margin around bounding box of painted stroke when crop mode triggered. | 256 |
|
||||
| --gui | Launch lama-cleaner as a desktop application | |
|
||||
| --gui_size | Set the window size for the application | 1200 900 |
|
||||
| --input | Path to image you want to load by default | None |
|
||||
| --port | Port for flask web server | 8080 |
|
||||
| --debug | Enable debug mode for flask web server | |
|
||||
| Name | Description | Default |
|
||||
| ---------- | ------------------------------------------------ | -------- |
|
||||
| --model | lama or ldm. See details in **Model Comparison** | lama |
|
||||
| --device | cuda or cpu | cuda |
|
||||
| --gui | Launch lama-cleaner as a desktop application | |
|
||||
| --gui_size | Set the window size for the application | 1200 900 |
|
||||
| --input | Path to image you want to load by default | None |
|
||||
| --port | Port for flask web server | 8080 |
|
||||
| --debug | Enable debug mode for flask web server | |
|
||||
|
||||
## Model Comparison
|
||||
|
||||
|
7
lama_cleaner/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from lama_cleaner.parse_args import parse_args
|
||||
from lama_cleaner.server import main
|
||||
|
||||
|
||||
def entry_point():
|
||||
args = parse_args()
|
||||
main(args)
|
@ -17,6 +17,8 @@
|
||||
"project": "./tsconfig.json"
|
||||
},
|
||||
"rules": {
|
||||
"jsx-a11y/click-events-have-key-events": 0,
|
||||
"react/jsx-props-no-spreading": 0,
|
||||
"import/no-unresolved": 0,
|
||||
"react/jsx-no-bind": "off",
|
||||
"react/jsx-filename-extension": [
|
||||
|
@ -1,17 +1,17 @@
|
||||
{
|
||||
"files": {
|
||||
"main.css": "/static/css/main.fd853425.chunk.css",
|
||||
"main.js": "/static/js/main.c06ba56c.chunk.js",
|
||||
"main.css": "/static/css/main.5c7abe60.chunk.css",
|
||||
"main.js": "/static/js/main.0a74f667.chunk.js",
|
||||
"runtime-main.js": "/static/js/runtime-main.5e86ac81.js",
|
||||
"static/js/2.97604ba9.chunk.js": "/static/js/2.97604ba9.chunk.js",
|
||||
"static/js/2.4cb726d6.chunk.js": "/static/js/2.4cb726d6.chunk.js",
|
||||
"index.html": "/index.html",
|
||||
"static/js/2.97604ba9.chunk.js.LICENSE.txt": "/static/js/2.97604ba9.chunk.js.LICENSE.txt",
|
||||
"static/js/2.4cb726d6.chunk.js.LICENSE.txt": "/static/js/2.4cb726d6.chunk.js.LICENSE.txt",
|
||||
"static/media/_index.scss": "/static/media/WorkSans-SemiBold.1e98db4e.ttf"
|
||||
},
|
||||
"entrypoints": [
|
||||
"static/js/runtime-main.5e86ac81.js",
|
||||
"static/js/2.97604ba9.chunk.js",
|
||||
"static/css/main.fd853425.chunk.css",
|
||||
"static/js/main.c06ba56c.chunk.js"
|
||||
"static/js/2.4cb726d6.chunk.js",
|
||||
"static/css/main.5c7abe60.chunk.css",
|
||||
"static/js/main.0a74f667.chunk.js"
|
||||
]
|
||||
}
|
@ -1 +1 @@
|
||||
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by LaMa</title><link href="/static/css/main.fd853425.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>"localhost"===location.hostname&&(self.FIREBASE_APPCHECK_DEBUG_TOKEN=!0)</script><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.97604ba9.chunk.js"></script><script src="/static/js/main.c06ba56c.chunk.js"></script></body></html>
|
||||
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,maximum-scale=1,user-scalable=0"/><meta name="theme-color" content="#ffffff"/><title>lama-cleaner - Image inpainting powered by LaMa</title><link href="/static/css/main.5c7abe60.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>"localhost"===location.hostname&&(self.FIREBASE_APPCHECK_DEBUG_TOKEN=!0)</script><script>!function(e){function r(r){for(var n,l,a=r[0],f=r[1],i=r[2],p=0,s=[];p<a.length;p++)l=a[p],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in f)Object.prototype.hasOwnProperty.call(f,n)&&(e[n]=f[n]);for(c&&c(r);s.length;)s.shift()();return u.push.apply(u,i||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,a=1;a<t.length;a++){var f=t[a];0!==o[f]&&(n=!1)}n&&(u.splice(r--,1),e=l(l.s=t[0]))}return e}var n={},o={1:0},u=[];function l(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,l),t.l=!0,t.exports}l.m=e,l.c=n,l.d=function(e,r,t){l.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,r){if(1&r&&(e=l(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(l.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)l.d(t,n,function(r){return e[r]}.bind(null,n));return t},l.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(r,"a",r),r},l.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},l.p="/";var a=this["webpackJsonplama-cleaner"]=this["webpackJsonplama-cleaner"]||[],f=a.push.bind(a);a.push=r,a=a.slice();for(var i=0;i<a.length;i++)r(a[i]);var c=f;t()}([])</script><script src="/static/js/2.4cb726d6.chunk.js"></script><script src="/static/js/main.0a74f667.chunk.js"></script></body></html>
|
2
lama_cleaner/app/build/static/js/2.4cb726d6.chunk.js
Normal file
1
lama_cleaner/app/build/static/js/main.0a74f667.chunk.js
Normal file
@ -5,6 +5,8 @@
|
||||
"proxy": "http://localhost:8080",
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^1.0.4",
|
||||
"@radix-ui/react-switch": "^0.1.5",
|
||||
"@radix-ui/react-toast": "^0.1.1",
|
||||
"@testing-library/jest-dom": "^5.14.1",
|
||||
"@testing-library/react": "^12.1.2",
|
||||
"@testing-library/user-event": "^13.5.0",
|
||||
|
@ -1,10 +1,12 @@
|
||||
import { Settings } from '../store/Atoms'
|
||||
import { dataURItoBlob } from '../utils'
|
||||
|
||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}/inpaint`
|
||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||
|
||||
export default async function inpaint(
|
||||
imageFile: File,
|
||||
maskBase64: string,
|
||||
settings: Settings,
|
||||
sizeLimit?: string
|
||||
) {
|
||||
// 1080, 2000, Original
|
||||
@ -13,13 +15,22 @@ export default async function inpaint(
|
||||
const mask = dataURItoBlob(maskBase64)
|
||||
fd.append('mask', mask)
|
||||
|
||||
fd.append('ldmSteps', settings.ldmSteps.toString())
|
||||
fd.append('hdStrategy', settings.hdStrategy)
|
||||
fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
|
||||
fd.append(
|
||||
'hdStrategyCropTrigerSize',
|
||||
settings.hdStrategyCropTrigerSize.toString()
|
||||
)
|
||||
fd.append('hdStrategyResizeLimit', settings.hdStrategyResizeLimit.toString())
|
||||
|
||||
if (sizeLimit === undefined) {
|
||||
fd.append('sizeLimit', '1080')
|
||||
} else {
|
||||
fd.append('sizeLimit', sizeLimit)
|
||||
}
|
||||
|
||||
const res = await fetch(API_ENDPOINT, {
|
||||
const res = await fetch(`${API_ENDPOINT}/inpaint`, {
|
||||
method: 'POST',
|
||||
body: fd,
|
||||
}).then(async r => {
|
||||
@ -28,3 +39,24 @@ export default async function inpaint(
|
||||
|
||||
return URL.createObjectURL(res)
|
||||
}
|
||||
|
||||
export function switchModel(name: string) {
|
||||
const fd = new FormData()
|
||||
fd.append('name', name)
|
||||
return fetch(`${API_ENDPOINT}/model`, {
|
||||
method: 'POST',
|
||||
body: fd,
|
||||
})
|
||||
}
|
||||
|
||||
export function currentModel() {
|
||||
return fetch(`${API_ENDPOINT}/model`, {
|
||||
method: 'GET',
|
||||
})
|
||||
}
|
||||
|
||||
export function modelDownloaded(name: string) {
|
||||
return fetch(`${API_ENDPOINT}/model_downloaded/${name}`, {
|
||||
method: 'GET',
|
||||
})
|
||||
}
|
||||
|
@ -15,12 +15,14 @@ import {
|
||||
TransformComponent,
|
||||
TransformWrapper,
|
||||
} from 'react-zoom-pan-pinch'
|
||||
import { useRecoilValue } from 'recoil'
|
||||
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
||||
import inpaint from '../../adapters/inpainting'
|
||||
import Button from '../shared/Button'
|
||||
import Slider from './Slider'
|
||||
import SizeSelector from './SizeSelector'
|
||||
import { downloadImage, loadImage, useImage } from '../../utils'
|
||||
import { settingState } from '../../store/Atoms'
|
||||
|
||||
const TOOLBAR_SIZE = 200
|
||||
const BRUSH_COLOR = '#ffcc00bb'
|
||||
@ -57,6 +59,7 @@ function drawLines(
|
||||
|
||||
export default function Editor(props: EditorProps) {
|
||||
const { file } = props
|
||||
const settings = useRecoilValue(settingState)
|
||||
const [brushSize, setBrushSize] = useState(40)
|
||||
const [original, isOriginalLoaded] = useImage(file)
|
||||
const [renders, setRenders] = useState<HTMLImageElement[]>([])
|
||||
@ -73,10 +76,11 @@ export default function Editor(props: EditorProps) {
|
||||
const [showOriginal, setShowOriginal] = useState(false)
|
||||
const [isInpaintingLoading, setIsInpaintingLoading] = useState(false)
|
||||
const [scale, setScale] = useState<number>(1)
|
||||
const [minScale, setMinScale] = useState<number>()
|
||||
const [minScale, setMinScale] = useState<number>(1.0)
|
||||
const [sizeLimit, setSizeLimit] = useState<number>(1080)
|
||||
const windowSize = useWindowSize()
|
||||
const viewportRef = useRef<ReactZoomPanPinchRef | undefined | null>()
|
||||
const [centered, setCentered] = useState(false)
|
||||
|
||||
const [isDraging, setIsDraging] = useState(false)
|
||||
const [isMultiStrokeKeyPressed, setIsMultiStrokeKeyPressed] = useState(false)
|
||||
@ -124,6 +128,7 @@ export default function Editor(props: EditorProps) {
|
||||
const res = await inpaint(
|
||||
file,
|
||||
maskCanvas.toDataURL(),
|
||||
settings,
|
||||
sizeLimit.toString()
|
||||
)
|
||||
if (!res) {
|
||||
@ -156,6 +161,7 @@ export default function Editor(props: EditorProps) {
|
||||
renders,
|
||||
sizeLimit,
|
||||
historyLineCount,
|
||||
settings,
|
||||
])
|
||||
|
||||
const hadDrawSomething = () => {
|
||||
@ -214,31 +220,38 @@ export default function Editor(props: EditorProps) {
|
||||
|
||||
// Draw once the original image is loaded
|
||||
useEffect(() => {
|
||||
if (!original) {
|
||||
if (!isOriginalLoaded) {
|
||||
return
|
||||
}
|
||||
|
||||
if (isOriginalLoaded) {
|
||||
const rW = windowSize.width / original.naturalWidth
|
||||
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
|
||||
if (rW < 1 || rH < 1) {
|
||||
const s = Math.min(rW, rH)
|
||||
setMinScale(s)
|
||||
setScale(s)
|
||||
} else {
|
||||
setMinScale(1)
|
||||
}
|
||||
const rW = windowSize.width / original.naturalWidth
|
||||
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
|
||||
|
||||
const imageSizeLimit = Math.max(original.width, original.height)
|
||||
setSizeLimit(imageSizeLimit)
|
||||
|
||||
if (context?.canvas) {
|
||||
context.canvas.width = original.naturalWidth
|
||||
context.canvas.height = original.naturalHeight
|
||||
}
|
||||
draw()
|
||||
let s = 1.0
|
||||
if (rW < 1 || rH < 1) {
|
||||
s = Math.min(rW, rH)
|
||||
}
|
||||
}, [context?.canvas, draw, original, isOriginalLoaded, windowSize])
|
||||
setMinScale(s)
|
||||
setScale(s)
|
||||
|
||||
const imageSizeLimit = Math.max(original.width, original.height)
|
||||
setSizeLimit(imageSizeLimit)
|
||||
|
||||
if (context?.canvas) {
|
||||
context.canvas.width = original.naturalWidth
|
||||
context.canvas.height = original.naturalHeight
|
||||
}
|
||||
viewportRef.current?.centerView(s, 1)
|
||||
setCentered(true)
|
||||
draw()
|
||||
}, [
|
||||
context?.canvas,
|
||||
draw,
|
||||
viewportRef,
|
||||
original,
|
||||
isOriginalLoaded,
|
||||
windowSize,
|
||||
])
|
||||
|
||||
// Zoom reset
|
||||
const resetZoom = useCallback(() => {
|
||||
@ -522,10 +535,6 @@ export default function Editor(props: EditorProps) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!original || !scale || !minScale) {
|
||||
return <></>
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className="editor-container"
|
||||
@ -543,7 +552,7 @@ export default function Editor(props: EditorProps) {
|
||||
wheel={{ step: 0.05 }}
|
||||
centerZoomedOut
|
||||
alignmentAnimation={{ disabled: true }}
|
||||
centerOnInit
|
||||
// centerOnInit
|
||||
limitToBounds={false}
|
||||
doubleClick={{ disabled: true }}
|
||||
initialScale={minScale}
|
||||
@ -554,6 +563,9 @@ export default function Editor(props: EditorProps) {
|
||||
>
|
||||
<TransformComponent
|
||||
contentClass={isInpaintingLoading ? 'editor-canvas-loading' : ''}
|
||||
contentStyle={{
|
||||
visibility: centered ? 'visible' : 'hidden',
|
||||
}}
|
||||
>
|
||||
<div className="editor-canvas-container">
|
||||
<canvas
|
||||
|
@ -23,3 +23,11 @@ header {
|
||||
gap: 12px;
|
||||
justify-self: end;
|
||||
}
|
||||
|
||||
.header-icons {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
justify-self: end;
|
||||
}
|
@ -6,6 +6,7 @@ import Button from '../shared/Button'
|
||||
import Shortcuts from '../Shortcuts/Shortcuts'
|
||||
import useResolution from '../../hooks/useResolution'
|
||||
import { ThemeChanger } from './ThemeChanger'
|
||||
import SettingIcon from '../Settings/SettingIcon'
|
||||
|
||||
const Header = () => {
|
||||
const [file, setFile] = useRecoilState(fileState)
|
||||
@ -26,7 +27,11 @@ const Header = () => {
|
||||
</Button>
|
||||
</div>
|
||||
<div className="header-icons-wrapper">
|
||||
<div style={{ visibility: file ? 'visible' : 'hidden' }}>
|
||||
<div
|
||||
className="header-icons"
|
||||
style={{ visibility: file ? 'visible' : 'hidden' }}
|
||||
>
|
||||
<SettingIcon />
|
||||
<Shortcuts />
|
||||
</div>
|
||||
<ThemeChanger />
|
||||
|
@ -0,0 +1,7 @@
|
||||
.hd-setting-block {
|
||||
.inline-tip {
|
||||
display: inline;
|
||||
cursor: pointer;
|
||||
color: var(--text-color);
|
||||
}
|
||||
}
|
132
lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx
Normal file
@ -0,0 +1,132 @@
|
||||
import React, { ReactNode } from 'react'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { settingState } from '../../store/Atoms'
|
||||
import Selector from '../shared/Selector'
|
||||
import NumberInputSetting from './NumberInputSetting'
|
||||
import SettingBlock from './SettingBlock'
|
||||
|
||||
export enum HDStrategy {
|
||||
ORIGINAL = 'Original',
|
||||
RESIZE = 'Resize',
|
||||
CROP = 'Crop',
|
||||
}
|
||||
|
||||
function HDSettingBlock() {
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
|
||||
const onStrategyChange = (value: HDStrategy) => {
|
||||
setSettingState(old => {
|
||||
return { ...old, hdStrategy: value }
|
||||
})
|
||||
}
|
||||
|
||||
const onResizeLimitChange = (value: string) => {
|
||||
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||
setSettingState(old => {
|
||||
return { ...old, hdStrategyResizeLimit: val }
|
||||
})
|
||||
}
|
||||
|
||||
const onCropTriggerSizeChange = (value: string) => {
|
||||
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||
setSettingState(old => {
|
||||
return { ...old, hdStrategyCropTrigerSize: val }
|
||||
})
|
||||
}
|
||||
|
||||
const onCropMarginChange = (value: string) => {
|
||||
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||
setSettingState(old => {
|
||||
return { ...old, hdStrategyCropMargin: val }
|
||||
})
|
||||
}
|
||||
|
||||
const renderOriginalOptionDesc = () => {
|
||||
return (
|
||||
<div>
|
||||
Use the original resolution of the picture, suitable for picture size
|
||||
below 2K. Try{' '}
|
||||
<div
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
className="inline-tip"
|
||||
onClick={() => onStrategyChange(HDStrategy.RESIZE)}
|
||||
>
|
||||
Resize Strategy
|
||||
</div>{' '}
|
||||
if you do not get good results on high resolution images.
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const renderResizeOptionDesc = () => {
|
||||
return (
|
||||
<div>
|
||||
<div>
|
||||
Resize the longer side of the image to a specific size(keep ratio),
|
||||
then do inpainting on the resized image.
|
||||
</div>
|
||||
<NumberInputSetting
|
||||
title="Size limit"
|
||||
value={`${setting.hdStrategyResizeLimit}`}
|
||||
suffix="pixel"
|
||||
onValue={onResizeLimitChange}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const renderCropOptionDesc = () => {
|
||||
return (
|
||||
<div>
|
||||
<div>
|
||||
Crop masking area from the original image to do inpainting, and paste
|
||||
the result back. Mainly for performance and memory reasons on high
|
||||
resolution image.
|
||||
</div>
|
||||
<NumberInputSetting
|
||||
title="Trigger size"
|
||||
value={`${setting.hdStrategyCropTrigerSize}`}
|
||||
suffix="pixel"
|
||||
onValue={onCropTriggerSizeChange}
|
||||
/>
|
||||
<NumberInputSetting
|
||||
title="Crop margin"
|
||||
value={`${setting.hdStrategyCropMargin}`}
|
||||
suffix="pixel"
|
||||
onValue={onCropMarginChange}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const renderHDStrategyOptionDesc = (): ReactNode => {
|
||||
switch (setting.hdStrategy) {
|
||||
case HDStrategy.ORIGINAL:
|
||||
return renderOriginalOptionDesc()
|
||||
case HDStrategy.CROP:
|
||||
return renderCropOptionDesc()
|
||||
case HDStrategy.RESIZE:
|
||||
return renderResizeOptionDesc()
|
||||
default:
|
||||
return renderOriginalOptionDesc()
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingBlock
|
||||
className="hd-setting-block"
|
||||
title="High Resolution Strategy"
|
||||
input={
|
||||
<Selector
|
||||
value={setting.hdStrategy as string}
|
||||
options={Object.values(HDStrategy)}
|
||||
onChange={val => onStrategyChange(val as HDStrategy)}
|
||||
/>
|
||||
}
|
||||
optionDesc={renderHDStrategyOptionDesc()}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default HDSettingBlock
|
@ -0,0 +1,4 @@
|
||||
.model-desc-link {
|
||||
color: var(--text-color-gray);
|
||||
text-decoration: none;
|
||||
}
|
103
lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
Normal file
@ -0,0 +1,103 @@
|
||||
import React, { ReactNode } from 'react'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { settingState } from '../../store/Atoms'
|
||||
import Selector from '../shared/Selector'
|
||||
import NumberInputSetting from './NumberInputSetting'
|
||||
import SettingBlock from './SettingBlock'
|
||||
|
||||
export enum AIModel {
|
||||
LAMA = 'lama',
|
||||
LDM = 'ldm',
|
||||
}
|
||||
|
||||
function ModelSettingBlock() {
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
|
||||
const onModelChange = (value: AIModel) => {
|
||||
setSettingState(old => {
|
||||
return { ...old, model: value }
|
||||
})
|
||||
}
|
||||
|
||||
const renderModelDesc = (
|
||||
name: string,
|
||||
paperUrl: string,
|
||||
githubUrl: string
|
||||
) => {
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
|
||||
<a
|
||||
className="model-desc-link"
|
||||
href={paperUrl}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
>
|
||||
{name}
|
||||
</a>
|
||||
|
||||
<a
|
||||
className="model-desc-link"
|
||||
href={githubUrl}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
>
|
||||
{githubUrl}
|
||||
</a>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const renderLDMModelDesc = () => {
|
||||
return (
|
||||
<div>
|
||||
{renderModelDesc(
|
||||
'High-Resolution Image Synthesis with Latent Diffusion Models',
|
||||
'https://arxiv.org/abs/2112.10752',
|
||||
'https://github.com/CompVis/latent-diffusion'
|
||||
)}
|
||||
<NumberInputSetting
|
||||
title="Steps"
|
||||
value={`${setting.ldmSteps}`}
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseInt(value, 10)
|
||||
setSettingState(old => {
|
||||
return { ...old, ldmSteps: val }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const renderOptionDesc = (): ReactNode => {
|
||||
switch (setting.model) {
|
||||
case AIModel.LAMA:
|
||||
return renderModelDesc(
|
||||
'Resolution-robust Large Mask Inpainting with Fourier Convolutions',
|
||||
'https://arxiv.org/abs/2109.07161',
|
||||
'https://github.com/saic-mdal/lama'
|
||||
)
|
||||
case AIModel.LDM:
|
||||
return renderLDMModelDesc()
|
||||
default:
|
||||
return <></>
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingBlock
|
||||
className="model-setting-block"
|
||||
title="Inpainting Model"
|
||||
input={
|
||||
<Selector
|
||||
value={setting.model as string}
|
||||
options={Object.values(AIModel)}
|
||||
onChange={val => onModelChange(val as AIModel)}
|
||||
/>
|
||||
}
|
||||
optionDesc={renderOptionDesc()}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelSettingBlock
|
@ -0,0 +1,40 @@
|
||||
import React from 'react'
|
||||
import NumberInput from '../shared/NumberInput'
|
||||
import SettingBlock from './SettingBlock'
|
||||
|
||||
interface NumberInputSettingProps {
|
||||
title: string
|
||||
value: string
|
||||
suffix?: string
|
||||
onValue: (val: string) => void
|
||||
}
|
||||
|
||||
function NumberInputSetting(props: NumberInputSettingProps) {
|
||||
const { title, value, suffix, onValue } = props
|
||||
|
||||
return (
|
||||
<SettingBlock
|
||||
className="sub-setting-block"
|
||||
title={title}
|
||||
input={
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
gap: '8px',
|
||||
}}
|
||||
>
|
||||
<NumberInput
|
||||
style={{ width: '80px' }}
|
||||
value={`${value}`}
|
||||
onValue={onValue}
|
||||
/>
|
||||
{suffix && <span>{suffix}</span>}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default NumberInputSetting
|
@ -0,0 +1,31 @@
|
||||
import React, { ReactNode } from 'react'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { settingState } from '../../store/Atoms'
|
||||
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||
import SettingBlock from './SettingBlock'
|
||||
|
||||
function SavePathSettingBlock() {
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
|
||||
const onCheckChange = (checked: boolean) => {
|
||||
setSettingState(old => {
|
||||
return { ...old, saveImageBesideOrigin: checked }
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingBlock
|
||||
title="Download image beside origin image"
|
||||
input={
|
||||
<Switch
|
||||
checked={setting.saveImageBesideOrigin}
|
||||
onCheckedChange={onCheckChange}
|
||||
>
|
||||
<SwitchThumb />
|
||||
</Switch>
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default SavePathSettingBlock
|
36
lama_cleaner/app/src/components/Settings/SettingBlock.scss
Normal file
@ -0,0 +1,36 @@
|
||||
.setting-block {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
.option-desc {
|
||||
color: var(--text-color-gray);
|
||||
margin-top: 12px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 0.3rem;
|
||||
padding: 1rem;
|
||||
|
||||
.sub-setting-block {
|
||||
margin-top: 8px;
|
||||
color: var(--text-color);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.setting-block-content {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
gap: 12rem;
|
||||
}
|
||||
|
||||
.setting-block-content-title {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.setting-block-desc {
|
||||
font-size: 1rem;
|
||||
margin-top: 8px;
|
||||
color: var(--text-color-gray);
|
||||
}
|
27
lama_cleaner/app/src/components/Settings/SettingBlock.tsx
Normal file
@ -0,0 +1,27 @@
|
||||
import React, { ReactNode } from 'react'
|
||||
|
||||
interface SettingBlockProps {
|
||||
title: string
|
||||
desc?: string
|
||||
input: ReactNode
|
||||
optionDesc?: ReactNode
|
||||
className?: string
|
||||
}
|
||||
|
||||
function SettingBlock(props: SettingBlockProps) {
|
||||
const { title, desc, input, optionDesc, className } = props
|
||||
return (
|
||||
<div className={`setting-block ${className}`}>
|
||||
<div className="setting-block-content">
|
||||
<div className="setting-block-content-title">
|
||||
<span>{title}</span>
|
||||
{desc && <span className="setting-block-desc">{desc}</span>}
|
||||
</div>
|
||||
{input}
|
||||
</div>
|
||||
{optionDesc && <div className="option-desc">{optionDesc}</div>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default SettingBlock
|
46
lama_cleaner/app/src/components/Settings/SettingIcon.tsx
Normal file
@ -0,0 +1,46 @@
|
||||
import React from 'react'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { settingState } from '../../store/Atoms'
|
||||
import Button from '../shared/Button'
|
||||
|
||||
const SettingIcon = () => {
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
|
||||
const onClick = () => {
|
||||
setSettingState({ ...setting, show: !setting.show })
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Button
|
||||
onClick={onClick}
|
||||
style={{ border: 0 }}
|
||||
icon={
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
role="img"
|
||||
width="28"
|
||||
height="28"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"
|
||||
/>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"
|
||||
/>
|
||||
</svg>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default SettingIcon
|
20
lama_cleaner/app/src/components/Settings/Settings.scss
Normal file
@ -0,0 +1,20 @@
|
||||
@use '../../styles/Mixins/' as *;
|
||||
@import './SettingBlock.scss';
|
||||
@import './HDSettingBlock.scss';
|
||||
@import './ModelSettingBlock.scss';
|
||||
|
||||
.modal-setting {
|
||||
grid-area: main-content;
|
||||
background-color: var(--modal-bg);
|
||||
color: var(--modal-text-color);
|
||||
box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
|
||||
width: 700px;
|
||||
|
||||
@include mobile {
|
||||
display: grid;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
margin-top: -11rem;
|
||||
animation: slideDown 0.2s ease-out;
|
||||
}
|
||||
}
|
37
lama_cleaner/app/src/components/Settings/SettingsModal.tsx
Normal file
@ -0,0 +1,37 @@
|
||||
import React from 'react'
|
||||
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { settingState } from '../../store/Atoms'
|
||||
import Modal from '../shared/Modal'
|
||||
import HDSettingBlock from './HDSettingBlock'
|
||||
import ModelSettingBlock from './ModelSettingBlock'
|
||||
|
||||
interface SettingModalProps {
|
||||
onClose: () => void
|
||||
}
|
||||
export default function SettingModal(props: SettingModalProps) {
|
||||
const { onClose } = props
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
|
||||
const handleOnClose = () => {
|
||||
setSettingState(old => {
|
||||
return { ...old, show: false }
|
||||
})
|
||||
onClose()
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal
|
||||
onClose={handleOnClose}
|
||||
title="Settings"
|
||||
className="modal-setting"
|
||||
show={setting.show}
|
||||
>
|
||||
{/* It's not possible because this poses a security risk */}
|
||||
{/* https://stackoverflow.com/questions/34870711/download-a-file-at-different-location-using-html5 */}
|
||||
{/* <SavePathSettingBlock /> */}
|
||||
<ModelSettingBlock />
|
||||
<HDSettingBlock />
|
||||
</Modal>
|
||||
)
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
import React, { ReactNode } from 'react'
|
||||
import { useSetRecoilState } from 'recoil'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { shortcutsState } from '../../store/Atoms'
|
||||
import Modal from '../shared/Modal'
|
||||
|
||||
@ -19,13 +19,8 @@ function ShortCut(props: Shortcut) {
|
||||
)
|
||||
}
|
||||
|
||||
interface ShortcutsModalProps {
|
||||
show: boolean
|
||||
}
|
||||
|
||||
export default function ShortcutsModal(props: ShortcutsModalProps) {
|
||||
const { show } = props
|
||||
const setShortcutState = useSetRecoilState(shortcutsState)
|
||||
export default function ShortcutsModal() {
|
||||
const [shortcutsShow, setShortcutState] = useRecoilState(shortcutsState)
|
||||
|
||||
const shortcutStateHandler = () => {
|
||||
setShortcutState(false)
|
||||
@ -36,7 +31,7 @@ export default function ShortcutsModal(props: ShortcutsModalProps) {
|
||||
onClose={shortcutStateHandler}
|
||||
title="Hotkeys"
|
||||
className="modal-shortcuts"
|
||||
show={show}
|
||||
show={shortcutsShow}
|
||||
>
|
||||
<div className="shortcut-options">
|
||||
<ShortCut content="Enable multi-stroke mask drawing">
|
||||
|
@ -1,19 +1,99 @@
|
||||
import React from 'react'
|
||||
import { useRecoilValue } from 'recoil'
|
||||
import React, { useEffect } from 'react'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import Editor from './Editor/Editor'
|
||||
import { shortcutsState } from '../store/Atoms'
|
||||
import ShortcutsModal from './Shortcuts/ShortcutsModal'
|
||||
import SettingModal from './Settings/SettingsModal'
|
||||
import Toast from './shared/Toast'
|
||||
import { Settings, settingState, toastState } from '../store/Atoms'
|
||||
import {
|
||||
currentModel,
|
||||
modelDownloaded,
|
||||
switchModel,
|
||||
} from '../adapters/inpainting'
|
||||
import { AIModel } from './Settings/ModelSettingBlock'
|
||||
|
||||
interface WorkspaceProps {
|
||||
file: File
|
||||
}
|
||||
|
||||
const Workspace = ({ file }: WorkspaceProps) => {
|
||||
const shortcutVisbility = useRecoilValue(shortcutsState)
|
||||
const [settings, setSettingState] = useRecoilState(settingState)
|
||||
const [toastVal, setToastState] = useRecoilState(toastState)
|
||||
|
||||
const onSettingClose = async () => {
|
||||
const curModel = await currentModel().then(res => res.text())
|
||||
if (curModel === settings.model) {
|
||||
return
|
||||
}
|
||||
const downloaded = await modelDownloaded(settings.model).then(res =>
|
||||
res.text()
|
||||
)
|
||||
|
||||
const { model } = settings
|
||||
|
||||
let loadingMessage = `Switching to ${model} model`
|
||||
let loadingDuration = 3000
|
||||
if (downloaded === 'False') {
|
||||
loadingMessage = `Downloading ${model} model, this may take a while`
|
||||
loadingDuration = 9999999999
|
||||
}
|
||||
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: loadingMessage,
|
||||
state: 'loading',
|
||||
duration: loadingDuration,
|
||||
})
|
||||
|
||||
switchModel(model)
|
||||
.then(res => {
|
||||
if (res.ok) {
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: `Switch to ${model} model success`,
|
||||
state: 'success',
|
||||
duration: 3000,
|
||||
})
|
||||
} else {
|
||||
throw new Error('Server error')
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: `Switch to ${model} model failed`,
|
||||
state: 'error',
|
||||
duration: 3000,
|
||||
})
|
||||
setSettingState(old => {
|
||||
return { ...old, model: curModel as AIModel }
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
currentModel()
|
||||
.then(res => res.text())
|
||||
.then(model => {
|
||||
setSettingState(old => {
|
||||
return { ...old, model: model as AIModel }
|
||||
})
|
||||
})
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<>
|
||||
<Editor file={file} />
|
||||
<ShortcutsModal show={shortcutVisbility} />
|
||||
<SettingModal onClose={onSettingClose} />
|
||||
<ShortcutsModal />
|
||||
<Toast
|
||||
{...toastVal}
|
||||
onOpenChange={(open: boolean) => {
|
||||
setToastState(old => {
|
||||
return { ...old, open }
|
||||
})
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { XIcon } from '@heroicons/react/outline'
|
||||
import React, { ReactNode, useRef } from 'react'
|
||||
import { useClickAway, useKey } from 'react-use'
|
||||
import { useClickAway, useKey, useKeyPress, useKeyPressEvent } from 'react-use'
|
||||
import Button from './Button'
|
||||
|
||||
export interface ModalProps {
|
||||
@ -16,11 +16,15 @@ export default function Modal(props: ModalProps) {
|
||||
const ref = useRef(null)
|
||||
|
||||
useClickAway(ref, () => {
|
||||
onClose?.()
|
||||
if (show) {
|
||||
onClose?.()
|
||||
}
|
||||
})
|
||||
|
||||
useKey('Escape', onClose, {
|
||||
event: 'keydown',
|
||||
useKeyPressEvent('Escape', e => {
|
||||
if (show) {
|
||||
onClose?.()
|
||||
}
|
||||
})
|
||||
|
||||
return (
|
||||
@ -30,7 +34,7 @@ export default function Modal(props: ModalProps) {
|
||||
>
|
||||
<div ref={ref} className={`modal ${className}`}>
|
||||
<div className="modal-header">
|
||||
<h3>{title}</h3>
|
||||
<h2>{title}</h2>
|
||||
<Button icon={<XIcon />} onClick={onClose} />
|
||||
</div>
|
||||
{children}
|
||||
|
12
lama_cleaner/app/src/components/shared/NumberInput.scss
Normal file
@ -0,0 +1,12 @@
|
||||
.number-input {
|
||||
all: unset;
|
||||
flex: 1 0 auto;
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.2rem 0.8rem;
|
||||
line-height: 1;
|
||||
outline: 1px solid var(--border-color);
|
||||
|
||||
&:focus-visible {
|
||||
outline: 1px solid var(--yellow-accent);
|
||||
}
|
||||
}
|
31
lama_cleaner/app/src/components/shared/NumberInput.tsx
Normal file
@ -0,0 +1,31 @@
|
||||
import React, { FormEvent, InputHTMLAttributes } from 'react'
|
||||
|
||||
interface NumberInputProps extends InputHTMLAttributes<HTMLInputElement> {
|
||||
value: string
|
||||
onValue?: (val: string) => void
|
||||
}
|
||||
|
||||
const NumberInput = React.forwardRef<HTMLInputElement, NumberInputProps>(
|
||||
(props: NumberInputProps, forwardedRef) => {
|
||||
const { value, onValue, ...itemProps } = props
|
||||
|
||||
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
|
||||
const target = evt.target as HTMLInputElement
|
||||
const val = target.value.replace(/\D/g, '')
|
||||
onValue?.(val)
|
||||
}
|
||||
|
||||
return (
|
||||
<input
|
||||
value={value}
|
||||
onInput={handleOnInput}
|
||||
className="number-input"
|
||||
{...itemProps}
|
||||
ref={forwardedRef}
|
||||
type="text"
|
||||
/>
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
export default NumberInput
|
74
lama_cleaner/app/src/components/shared/Selector.scss
Normal file
@ -0,0 +1,74 @@
|
||||
@use '../../styles/Mixins' as *;
|
||||
|
||||
.selector {
|
||||
position: relative;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.selector-main {
|
||||
@include accented-display(var(white));
|
||||
width: 100%;
|
||||
user-select: none;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
cursor: pointer;
|
||||
outline: none;
|
||||
gap: 8px;
|
||||
padding: 0.2rem 0.8rem;
|
||||
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--options-text-color);
|
||||
|
||||
svg {
|
||||
width: 1rem;
|
||||
height: 1rem;
|
||||
margin-top: 0.25rem;
|
||||
}
|
||||
}
|
||||
|
||||
.selector-options {
|
||||
@include accented-display(var(--btn-primary-bg));
|
||||
width: 100%;
|
||||
padding: 0;
|
||||
display: grid;
|
||||
justify-self: center;
|
||||
position: absolute;
|
||||
cursor: pointer;
|
||||
top: 3rem;
|
||||
|
||||
color: var(--options-text-color);
|
||||
background-color: var(--page-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
|
||||
border-radius: 0.6rem;
|
||||
|
||||
@include mobile {
|
||||
bottom: 11.5rem;
|
||||
}
|
||||
|
||||
.selector-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
user-select: none;
|
||||
padding: 0.5rem 0.8rem;
|
||||
|
||||
&:first-of-type {
|
||||
border-top-right-radius: 0.5rem;
|
||||
border-top-left-radius: 0.5rem;
|
||||
}
|
||||
|
||||
&:last-of-type {
|
||||
border-bottom-left-radius: 0.5rem;
|
||||
border-bottom-right-radius: 0.5rem;
|
||||
}
|
||||
|
||||
&:hover {
|
||||
background-color: var(--yellow-accent);
|
||||
color: var(--btn-text-hover-color);
|
||||
}
|
||||
}
|
||||
}
|
86
lama_cleaner/app/src/components/shared/Selector.tsx
Normal file
@ -0,0 +1,86 @@
|
||||
import React, { MutableRefObject, useCallback, useRef, useState } from 'react'
|
||||
import { useClickAway, useKeyPressEvent } from 'react-use'
|
||||
import { ChevronDownIcon, ChevronUpIcon } from '@heroicons/react/outline'
|
||||
|
||||
type SelectorChevronDirection = 'up' | 'down'
|
||||
|
||||
type SelectorProps = {
|
||||
minWidth?: number
|
||||
chevronDirection?: SelectorChevronDirection
|
||||
value: string
|
||||
options: string[]
|
||||
onChange: (value: string) => void
|
||||
}
|
||||
|
||||
const selectorDefaultProps = {
|
||||
minWidth: 128,
|
||||
chevronDirection: 'down',
|
||||
}
|
||||
|
||||
function Selector(props: SelectorProps) {
|
||||
const { minWidth, chevronDirection, value, options, onChange } = props
|
||||
const [showOptions, setShowOptions] = useState<boolean>(false)
|
||||
const selectorRef = useRef<HTMLDivElement | null>(null)
|
||||
|
||||
const showOptionsHandler = () => {
|
||||
// console.log(selectorRef.current?.focus)
|
||||
// selectorRef?.current?.focus()
|
||||
setShowOptions(currentShowOptionsState => !currentShowOptionsState)
|
||||
}
|
||||
|
||||
useClickAway(selectorRef, () => {
|
||||
setShowOptions(false)
|
||||
})
|
||||
|
||||
// TODO: how to prevent Modal close?
|
||||
// useKeyPressEvent('Escape', (e: KeyboardEvent) => {
|
||||
// if (showOptions === true) {
|
||||
// console.log(`selector ${e}`)
|
||||
// e.preventDefault()
|
||||
// e.stopPropagation()
|
||||
// setShowOptions(false)
|
||||
// }
|
||||
// })
|
||||
|
||||
const onOptionClick = (e: any, newIndex: number) => {
|
||||
const currentRes = e.target.textContent.split('x')
|
||||
onChange(currentRes[0])
|
||||
setShowOptions(false)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="selector" ref={selectorRef} style={{ minWidth }}>
|
||||
<div
|
||||
className="selector-main"
|
||||
role="button"
|
||||
onClick={showOptionsHandler}
|
||||
aria-hidden="true"
|
||||
>
|
||||
<p>{value}</p>
|
||||
<div className="selector-icon">
|
||||
{chevronDirection === 'up' ? <ChevronUpIcon /> : <ChevronDownIcon />}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{showOptions && (
|
||||
<div className="selector-options">
|
||||
{options.map((val, _index) => (
|
||||
<div
|
||||
className="selector-option"
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
key={val}
|
||||
onClick={e => onOptionClick(e, _index)}
|
||||
aria-hidden="true"
|
||||
>
|
||||
{val}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
Selector.defaultProps = selectorDefaultProps
|
||||
export default Selector
|
36
lama_cleaner/app/src/components/shared/Switch.scss
Normal file
@ -0,0 +1,36 @@
|
||||
.switch-root {
|
||||
all: 'unset';
|
||||
width: 42px;
|
||||
height: 25px;
|
||||
background-color: var(--switch-root-background-color);
|
||||
border-radius: 9999px;
|
||||
border: none;
|
||||
position: relative;
|
||||
transition: background-color 100ms;
|
||||
-webkit-tap-highlight-color: rgba(0, 0, 0, 0);
|
||||
|
||||
&:focus-visible {
|
||||
outline: none;
|
||||
}
|
||||
}
|
||||
|
||||
.switch-root[data-state='checked'] {
|
||||
background-color: var(--yellow-accent);
|
||||
}
|
||||
|
||||
.switch-thumb {
|
||||
display: block;
|
||||
width: 17px;
|
||||
height: 17px;
|
||||
background-color: var(--switch-thumb-color);
|
||||
border-radius: 9999px;
|
||||
transition: transform 100ms;
|
||||
transform: translateX(4px);
|
||||
will-change: transform;
|
||||
}
|
||||
|
||||
.switch-thumb[data-state='checked'] {
|
||||
transform: translateX(21px);
|
||||
background-color: var(--switch-thumb-checked-color);
|
||||
outline: 1px solid rgb(100, 100, 120, 0.5);
|
||||
}
|
34
lama_cleaner/app/src/components/shared/Switch.tsx
Normal file
@ -0,0 +1,34 @@
|
||||
import React from 'react'
|
||||
import * as SwitchPrimitive from '@radix-ui/react-switch'
|
||||
|
||||
const Switch = React.forwardRef<
|
||||
React.ElementRef<typeof SwitchPrimitive.Root>,
|
||||
React.ComponentProps<typeof SwitchPrimitive.Root>
|
||||
>((props, forwardedRef) => {
|
||||
const { className, ...itemProps } = props
|
||||
|
||||
return (
|
||||
<SwitchPrimitive.Root
|
||||
{...itemProps}
|
||||
ref={forwardedRef}
|
||||
className={`switch-root ${className}`}
|
||||
/>
|
||||
)
|
||||
})
|
||||
|
||||
const SwitchThumb = React.forwardRef<
|
||||
React.ElementRef<typeof SwitchPrimitive.Thumb>,
|
||||
React.ComponentProps<typeof SwitchPrimitive.Thumb>
|
||||
>((props, forwardedRef) => {
|
||||
const { className, ...itemProps } = props
|
||||
|
||||
return (
|
||||
<SwitchPrimitive.Thumb
|
||||
{...itemProps}
|
||||
ref={forwardedRef}
|
||||
className={`switch-thumb ${className}`}
|
||||
/>
|
||||
)
|
||||
})
|
||||
|
||||
export { Switch, SwitchThumb }
|
83
lama_cleaner/app/src/components/shared/Toast.scss
Normal file
@ -0,0 +1,83 @@
|
||||
.toast-viewpoint {
|
||||
position: fixed;
|
||||
top: 48px;
|
||||
right: 0;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
padding: 25px;
|
||||
gap: 10px;
|
||||
max-width: 100vw;
|
||||
margin: 0;
|
||||
z-index: 999999;
|
||||
|
||||
&:focus-visible {
|
||||
outline: none;
|
||||
}
|
||||
}
|
||||
|
||||
.toast-root {
|
||||
border: 1px solid var(--border-color-light);
|
||||
background-color: var(--page-bg);
|
||||
border-radius: 0.6rem;
|
||||
padding: 15px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
||||
gap: 12px;
|
||||
|
||||
&[data-state='open'] {
|
||||
animation: slideIn 150ms cubic-bezier(0.16, 1, 0.3, 1);
|
||||
}
|
||||
|
||||
&[data-state='close'] {
|
||||
animation: opacityReveal 100ms ease-in forwards;
|
||||
}
|
||||
|
||||
&[data-state='cancel'] {
|
||||
transform: translateX(0);
|
||||
animation: transform 100ms ease-out;
|
||||
}
|
||||
|
||||
&.error {
|
||||
border: 1px solid var(--error-color);
|
||||
}
|
||||
|
||||
&.success {
|
||||
border: 1px solid var(--success-color);
|
||||
}
|
||||
}
|
||||
|
||||
.error-icon {
|
||||
height: 24px;
|
||||
width: 24px;
|
||||
color: var(--error-color);
|
||||
}
|
||||
|
||||
.success-icon {
|
||||
height: 24px;
|
||||
width: 24px;
|
||||
color: var(--success-color);
|
||||
}
|
||||
|
||||
.loading-icon {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
animation-name: spin;
|
||||
animation-duration: 1500ms;
|
||||
animation-iteration-count: infinite;
|
||||
transform-origin: center center;
|
||||
animation-timing-function: linear;
|
||||
}
|
||||
|
||||
.toast-icon {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.toast-desc {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 0;
|
||||
color: var(--text-color);
|
||||
min-width: 240px;
|
||||
}
|
81
lama_cleaner/app/src/components/shared/Toast.tsx
Normal file
@ -0,0 +1,81 @@
|
||||
import * as React from 'react'
|
||||
import * as ToastPrimitive from '@radix-ui/react-toast'
|
||||
import { ToastProps } from '@radix-ui/react-toast'
|
||||
import { CheckIcon, ExclamationCircleIcon } from '@heroicons/react/outline'
|
||||
|
||||
const LoadingIcon = () => {
|
||||
return (
|
||||
<span className="loading-icon">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="20"
|
||||
height="20"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
>
|
||||
<line x1="12" y1="2" x2="12" y2="6" />
|
||||
<line x1="12" y1="18" x2="12" y2="22" />
|
||||
<line x1="4.93" y1="4.93" x2="7.76" y2="7.76" />
|
||||
<line x1="16.24" y1="16.24" x2="19.07" y2="19.07" />
|
||||
<line x1="2" y1="12" x2="6" y2="12" />
|
||||
<line x1="18" y1="12" x2="22" y2="12" />
|
||||
<line x1="4.93" y1="19.07" x2="7.76" y2="16.24" />
|
||||
<line x1="16.24" y1="7.76" x2="19.07" y2="4.93" />
|
||||
</svg>
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
export type ToastState = 'default' | 'error' | 'loading' | 'success'
|
||||
|
||||
interface MyToastProps extends ToastProps {
|
||||
desc: string
|
||||
state?: ToastState
|
||||
}
|
||||
|
||||
const Toast = React.forwardRef<
|
||||
React.ElementRef<typeof ToastPrimitive.Root>,
|
||||
MyToastProps
|
||||
>((props, forwardedRef) => {
|
||||
const { state, desc, ...itemProps } = props
|
||||
|
||||
const getIcon = () => {
|
||||
switch (state) {
|
||||
case 'error':
|
||||
return <ExclamationCircleIcon className="error-icon" />
|
||||
case 'success':
|
||||
return <CheckIcon className="success-icon" />
|
||||
case 'loading':
|
||||
return <LoadingIcon />
|
||||
default:
|
||||
return <></>
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<ToastPrimitive.Provider>
|
||||
<ToastPrimitive.Root
|
||||
{...itemProps}
|
||||
ref={forwardedRef}
|
||||
className={`toast-root ${state}`}
|
||||
>
|
||||
<div className="toast-icon">{getIcon()}</div>
|
||||
<ToastPrimitive.Description className="toast-desc">
|
||||
{desc}
|
||||
</ToastPrimitive.Description>
|
||||
</ToastPrimitive.Root>
|
||||
<ToastPrimitive.Viewport className="toast-viewpoint" />
|
||||
</ToastPrimitive.Provider>
|
||||
)
|
||||
})
|
||||
|
||||
Toast.defaultProps = {
|
||||
desc: '',
|
||||
state: 'loading',
|
||||
}
|
||||
|
||||
export default Toast
|
@ -8,14 +8,21 @@ export default function useInputImage() {
|
||||
headers.append('pragma', 'no-cache')
|
||||
headers.append('cache-control', 'no-cache')
|
||||
|
||||
fetch('/inputimage', { headers })
|
||||
.then(res => res.blob())
|
||||
.then(data => {
|
||||
if (data && data.type.startsWith('image')) {
|
||||
const userInput = new File([data], 'inputImage')
|
||||
setInputImage(userInput)
|
||||
}
|
||||
})
|
||||
fetch('/inputimage', { headers }).then(async res => {
|
||||
const filename = res.headers
|
||||
.get('content-disposition')
|
||||
?.split('filename=')[1]
|
||||
.split(';')[0]
|
||||
|
||||
const data = await res.blob()
|
||||
if (data && data.type.startsWith('image')) {
|
||||
const userInput = new File(
|
||||
[data],
|
||||
filename !== undefined ? filename : 'inputImage'
|
||||
)
|
||||
setInputImage(userInput)
|
||||
}
|
||||
})
|
||||
}, [setInputImage])
|
||||
|
||||
useEffect(() => {
|
||||
|
@ -1,11 +1,82 @@
|
||||
import { atom } from 'recoil'
|
||||
import { HDStrategy } from '../components/Settings/HDSettingBlock'
|
||||
import { AIModel } from '../components/Settings/ModelSettingBlock'
|
||||
import { ToastState } from '../components/shared/Toast'
|
||||
|
||||
export const fileState = atom<File | undefined>({
|
||||
key: 'fileState',
|
||||
default: undefined,
|
||||
})
|
||||
|
||||
interface ToastAtomState {
|
||||
open: boolean
|
||||
desc: string
|
||||
state: ToastState
|
||||
duration: number
|
||||
}
|
||||
|
||||
export const toastState = atom<ToastAtomState>({
|
||||
key: 'toastState',
|
||||
default: {
|
||||
open: false,
|
||||
desc: '',
|
||||
state: 'default',
|
||||
duration: 3000,
|
||||
},
|
||||
})
|
||||
|
||||
export const shortcutsState = atom<boolean>({
|
||||
key: 'shortcutsState',
|
||||
default: false,
|
||||
})
|
||||
|
||||
export interface Settings {
|
||||
show: boolean
|
||||
saveImageBesideOrigin: boolean
|
||||
model: AIModel
|
||||
|
||||
// For LaMa
|
||||
hdStrategy: HDStrategy
|
||||
hdStrategyResizeLimit: number
|
||||
hdStrategyCropTrigerSize: number
|
||||
hdStrategyCropMargin: number
|
||||
|
||||
// For LDM
|
||||
ldmSteps: number
|
||||
}
|
||||
|
||||
export const settingStateDefault = {
|
||||
show: false,
|
||||
saveImageBesideOrigin: false,
|
||||
model: AIModel.LAMA,
|
||||
ldmSteps: 50,
|
||||
hdStrategy: HDStrategy.RESIZE,
|
||||
hdStrategyResizeLimit: 2048,
|
||||
hdStrategyCropTrigerSize: 2048,
|
||||
hdStrategyCropMargin: 128,
|
||||
}
|
||||
|
||||
const localStorageEffect =
|
||||
(key: string) =>
|
||||
({ setSelf, onSet }: any) => {
|
||||
const savedValue = localStorage.getItem(key)
|
||||
if (savedValue != null) {
|
||||
const storageSettings = JSON.parse(savedValue)
|
||||
storageSettings.show = false
|
||||
setSelf(storageSettings)
|
||||
}
|
||||
|
||||
onSet((newValue: Settings, _: string, isReset: boolean) =>
|
||||
isReset
|
||||
? localStorage.removeItem(key)
|
||||
: localStorage.setItem(key, JSON.stringify(newValue))
|
||||
)
|
||||
}
|
||||
|
||||
// Each atom can reference an array of these atom effect functions which are called in priority order when the atom is initialized
|
||||
// https://recoiljs.org/docs/guides/atom-effects/#local-storage-persistence
|
||||
export const settingState = atom<Settings>({
|
||||
key: 'settingsState',
|
||||
default: settingStateDefault,
|
||||
effects: [localStorageEffect('settingsState')],
|
||||
})
|
||||
|
@ -37,3 +37,21 @@
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes slideIn {
|
||||
0% {
|
||||
transform: translateX(calc(100% + 25px));
|
||||
}
|
||||
100% {
|
||||
transform: translateX(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
100% {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,10 @@
|
||||
--yellow-accent: #ffcc00;
|
||||
--link-color: rgb(0, 0, 0);
|
||||
--border-color: rgb(100, 100, 120);
|
||||
--border-color-light: rgba(100, 100, 120, 0.5);
|
||||
|
||||
--error-color: rgb(239, 68, 68);
|
||||
--success-color: rgb(16, 185, 129);
|
||||
|
||||
// Editor
|
||||
--editor-toolkit-bg: rgba(255, 255, 255, 0.5);
|
||||
@ -17,13 +21,22 @@
|
||||
--modal-bg: var(--page-bg);
|
||||
--modal-text-color: rgb(0, 0, 0);
|
||||
--modal-hotkey-border-color: rgb(0, 0, 0);
|
||||
--model-mask-bg: rgba(209,213,219,0.4);
|
||||
--model-mask-bg: rgba(209, 213, 219, 0.4);
|
||||
|
||||
// Text
|
||||
--text-color: #040404;
|
||||
--text-color-gray: rgb(107, 111, 118);
|
||||
|
||||
// Shared
|
||||
--btn-primary-bg: rgb(210, 210, 220);
|
||||
--btn-text-color: black;
|
||||
--btn-text-hover-color: black;
|
||||
--btn-text-color: var(--text-color);
|
||||
--btn-text-hover-color: #040404;
|
||||
--btn-border-color: rgb(100, 100, 120);
|
||||
--btn-primary-hover-bg: var(--yellow-accent);
|
||||
--animation-pulsing-bg: rgb(255, 255, 255, 0.5);
|
||||
|
||||
// switch
|
||||
--switch-root-background-color: rgb(223, 225, 228);
|
||||
--switch-thumb-color: var(--page-bg);
|
||||
--switch-thumb-checked-color: var(--page-bg);
|
||||
}
|
||||
|
@ -3,10 +3,11 @@
|
||||
|
||||
// Theme
|
||||
--page-bg: #040404;
|
||||
--page-text-color: #F9F9F9;
|
||||
--page-text-color: #f9f9f9;
|
||||
--yellow-accent: #ffcc00;
|
||||
--link-color: var(--yellow-accent);
|
||||
--border-color: rgb(100, 100, 120);
|
||||
--border-color-light: rgba(102, 102, 102);
|
||||
|
||||
// Editor
|
||||
--editor-toolkit-bg: rgba(0, 0, 0, 0.5);
|
||||
@ -17,14 +18,23 @@
|
||||
--modal-bg: var(--page-bg);
|
||||
--modal-text-color: var(--page-text-color);
|
||||
// --modal-hotkey-bg: rgb(60, 60, 90);
|
||||
--modal-hotkey-border-color: var(--page-text-color);;
|
||||
--modal-hotkey-border-color: var(--page-text-color);
|
||||
--model-mask-bg: rgba(76, 76, 87, 0.4);
|
||||
|
||||
// Text
|
||||
--text-color: white;
|
||||
--text-color-gray: rgb(138, 143, 152);
|
||||
|
||||
// Shared
|
||||
--btn-primary-bg: rgb(140, 140, 180);
|
||||
--btn-text-color: white;
|
||||
--btn-text-color: var(--text-color);
|
||||
--btn-text-hover-color: var(--page-bg);
|
||||
--btn-border-color: var(--yellow-accent);
|
||||
--btn-primary-hover-bg: var(--yellow-accent);
|
||||
--animation-pulsing-bg: rgb(240, 240, 255);
|
||||
|
||||
// switch
|
||||
--switch-root-background-color: rgb(60, 63, 68);
|
||||
--switch-thumb-color: rgb(31, 32, 35);
|
||||
--switch-thumb-checked-color: white;
|
||||
}
|
||||
|
@ -11,11 +11,16 @@
|
||||
@use '../components/Header/Header';
|
||||
@use '../components/Header/ThemeChanger';
|
||||
@use '../components/Shortcuts/Shortcuts';
|
||||
@use '../components/Settings/Settings.scss';
|
||||
|
||||
// Shared
|
||||
@use '../components/FileSelect/FileSelect';
|
||||
@use '../components/shared/Button';
|
||||
@use '../components/shared/Modal';
|
||||
@use '../components/shared/Selector';
|
||||
@use '../components/shared/Switch';
|
||||
@use '../components/shared/NumberInput';
|
||||
@use '../components/shared/Toast';
|
||||
|
||||
// Main CSS
|
||||
*,
|
||||
|
@ -1164,6 +1164,13 @@
|
||||
dependencies:
|
||||
regenerator-runtime "^0.13.4"
|
||||
|
||||
"@babel/runtime@^7.13.10":
|
||||
version "7.17.9"
|
||||
resolved "https://registry.npmmirror.com/@babel/runtime/-/runtime-7.17.9.tgz#d19fbf802d01a8cb6cf053a64e472d42c434ba72"
|
||||
integrity sha512-lSiBBvodq29uShpWGNbgFdKYNiFDo5/HIYsaCEY9ff4sb10x9jizo2+pRrSyF4jKZCXqgzuqBOQKbUm90gQwJg==
|
||||
dependencies:
|
||||
regenerator-runtime "^0.13.4"
|
||||
|
||||
"@babel/template@^7.10.4", "@babel/template@^7.15.4", "@babel/template@^7.3.3":
|
||||
version "7.15.4"
|
||||
resolved "https://registry.npmjs.org/@babel/template/-/template-7.15.4.tgz"
|
||||
@ -1537,6 +1544,186 @@
|
||||
schema-utils "^2.6.5"
|
||||
source-map "^0.7.3"
|
||||
|
||||
"@radix-ui/primitive@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/primitive/-/primitive-0.1.0.tgz#6206b97d379994f0d1929809db035733b337e543"
|
||||
integrity sha512-tqxZKybwN5Fa3VzZry4G6mXAAb9aAqKmPtnVbZpL0vsBwvOHTBwsjHVPXylocYLwEtBY9SCe665bYnNB515uoA==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-compose-refs@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-compose-refs/-/react-compose-refs-0.1.0.tgz#cff6e780a0f73778b976acff2c2a5b6551caab95"
|
||||
integrity sha512-eyclbh+b77k+69Dk72q3694OHrn9B3QsoIRx7ywX341U9RK1ThgQjMFZoPtmZNQTksXHLNEiefR8hGVeFyInGg==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-context@0.1.1":
|
||||
version "0.1.1"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-context/-/react-context-0.1.1.tgz#06996829ea124d9a1bc1dbe3e51f33588fab0875"
|
||||
integrity sha512-PkyVX1JsLBioeu0jB9WvRpDBBLtLZohVDT3BB5CTSJqActma8S8030P57mWZb4baZifMvN7KKWPAA40UmWKkQg==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-dismissable-layer@0.1.5":
|
||||
version "0.1.5"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-0.1.5.tgz#9379032351e79028d472733a5cc8ba4a0ea43314"
|
||||
integrity sha512-J+fYWijkX4M4QKwf9dtu1oC0U6e6CEl8WhBp3Ad23yz2Hia0XCo6Pk/mp5CAFy4QBtQedTSkhW05AdtSOEoajQ==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/primitive" "0.1.0"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
"@radix-ui/react-use-body-pointer-events" "0.1.1"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
"@radix-ui/react-use-escape-keydown" "0.1.0"
|
||||
|
||||
"@radix-ui/react-id@0.1.5":
|
||||
version "0.1.5"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8"
|
||||
integrity sha512-IPc4H/63bes0IZ1GJJozSEkSWcDyhNGtKFWUpJ+XtaLyQ1X3x7Mf6fWwWhDcpqlYEP+5WtAvfqcyEsyjP+ZhBQ==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
|
||||
"@radix-ui/react-label@0.1.5":
|
||||
version "0.1.5"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-label/-/react-label-0.1.5.tgz#12cd965bfc983e0148121d4c99fb8e27a917c45c"
|
||||
integrity sha512-Au9+n4/DhvjR0IHhvZ1LPdx/OW+3CGDie30ZyCkbSHIuLp4/CV4oPPGBwJ1vY99Jog3zyQhsGww9MXj8O9Aj/A==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-context" "0.1.1"
|
||||
"@radix-ui/react-id" "0.1.5"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
|
||||
"@radix-ui/react-portal@0.1.4":
|
||||
version "0.1.4"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2"
|
||||
integrity sha512-MO0wRy2eYRTZ/CyOri9NANCAtAtq89DEtg90gicaTlkCfdqCLEBsLb+/q66BZQTr3xX/Vq01nnVfc/TkCqoqvw==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
|
||||
"@radix-ui/react-presence@0.1.2":
|
||||
version "0.1.2"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3"
|
||||
integrity sha512-3BRlFZraooIUfRlyN+b/Xs5hq1lanOOo/+3h6Pwu2GMFjkGKKa4Rd51fcqGqnVlbr3jYg+WLuGyAV4KlgqwrQw==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
|
||||
"@radix-ui/react-primitive@0.1.4":
|
||||
version "0.1.4"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1"
|
||||
integrity sha512-6gSl2IidySupIMJFjYnDIkIWRyQdbu/AHK7rbICPani+LW4b0XdxBXc46og/iZvuwW8pjCS8I2SadIerv84xYA==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-slot" "0.1.2"
|
||||
|
||||
"@radix-ui/react-slot@0.1.2":
|
||||
version "0.1.2"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-slot/-/react-slot-0.1.2.tgz#e6f7ad9caa8ce81cc8d532c854c56f9b8b6307c8"
|
||||
integrity sha512-ADkqfL+agEzEguU3yS26jfB50hRrwf7U4VTwAOZEmi/g+ITcBWe12yM46ueS/UCIMI9Py+gFUaAdxgxafFvY2Q==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
|
||||
"@radix-ui/react-switch@^0.1.5":
|
||||
version "0.1.5"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-switch/-/react-switch-0.1.5.tgz#071ffa19a17a47fdc5c5e6f371bd5901c9fef2f4"
|
||||
integrity sha512-ITtslJPK+Yi34iNf7K9LtsPaLD76oRIVzn0E8JpEO5HW8gpRBGb2NNI9mxKtEB30TVqIcdjdL10AmuIfOMwjtg==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/primitive" "0.1.0"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-context" "0.1.1"
|
||||
"@radix-ui/react-label" "0.1.5"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
"@radix-ui/react-use-controllable-state" "0.1.0"
|
||||
"@radix-ui/react-use-previous" "0.1.1"
|
||||
"@radix-ui/react-use-size" "0.1.1"
|
||||
|
||||
"@radix-ui/react-toast@^0.1.1":
|
||||
version "0.1.1"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-toast/-/react-toast-0.1.1.tgz#d544e796b307e56f1298e40f356f468680958e93"
|
||||
integrity sha512-9JWC4mPP78OE6muDrpaPf/71dIeozppdcnik1IvsjTxZpDnt9PbTtQj94DdWjlCphbv3S5faD3KL0GOpqKBpTQ==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/primitive" "0.1.0"
|
||||
"@radix-ui/react-compose-refs" "0.1.0"
|
||||
"@radix-ui/react-context" "0.1.1"
|
||||
"@radix-ui/react-dismissable-layer" "0.1.5"
|
||||
"@radix-ui/react-portal" "0.1.4"
|
||||
"@radix-ui/react-presence" "0.1.2"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
"@radix-ui/react-use-controllable-state" "0.1.0"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
"@radix-ui/react-visually-hidden" "0.1.4"
|
||||
|
||||
"@radix-ui/react-use-body-pointer-events@0.1.1":
|
||||
version "0.1.1"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-body-pointer-events/-/react-use-body-pointer-events-0.1.1.tgz#63e7fd81ca7ffd30841deb584cd2b7f460df2597"
|
||||
integrity sha512-R8leV2AWmJokTmERM8cMXFHWSiv/fzOLhG/JLmRBhLTAzOj37EQizssq4oW0Z29VcZy2tODMi9Pk/htxwb+xpA==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-use-layout-effect" "0.1.0"
|
||||
|
||||
"@radix-ui/react-use-callback-ref@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-0.1.0.tgz#934b6e123330f5b3a6b116460e6662cbc663493f"
|
||||
integrity sha512-Va041McOFFl+aV+sejvl0BS2aeHx86ND9X/rVFmEFQKTXCp6xgUK0NGUAGcgBlIjnJSbMYPGEk1xKSSlVcN2Aw==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-use-controllable-state@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-0.1.0.tgz#4fced164acfc69a4e34fb9d193afdab973a55de1"
|
||||
integrity sha512-zv7CX/PgsRl46a52Tl45TwqwVJdmqnlQEQhaYMz/yBOD2sx2gCkCFSoF/z9mpnYWmS6DTLNTg5lIps3fV6EnXg==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
|
||||
"@radix-ui/react-use-escape-keydown@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874"
|
||||
integrity sha512-tDLZbTGFmvXaazUXXv8kYbiCcbAE8yKgng9s95d8fCO+Eundv0Jngbn/hKPhDDs4jj9ChwRX5cDDnlaN+ugYYQ==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-use-callback-ref" "0.1.0"
|
||||
|
||||
"@radix-ui/react-use-layout-effect@0.1.0":
|
||||
version "0.1.0"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223"
|
||||
integrity sha512-+wdeS51Y+E1q1Wmd+1xSSbesZkpVj4jsg0BojCbopWvgq5iBvixw5vgemscdh58ep98BwUbsFYnrywFhV9yrVg==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-use-previous@0.1.1":
|
||||
version "0.1.1"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-previous/-/react-use-previous-0.1.1.tgz#0226017f72267200f6e832a7103760e96a6db5d0"
|
||||
integrity sha512-O/ZgrDBr11dR8rhO59ED8s5zIXBRFi8MiS+CmFGfi7MJYdLbfqVOmQU90Ghf87aifEgWe6380LA69KBneaShAg==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-use-size@0.1.1":
|
||||
version "0.1.1"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-use-size/-/react-use-size-0.1.1.tgz#f6b75272a5d41c3089ca78c8a2e48e5f204ef90f"
|
||||
integrity sha512-pTgWM5qKBu6C7kfKxrKPoBI2zZYZmp2cSXzpUiGM3qEBQlMLtYhaY2JXdXUCxz+XmD1YEjc8oRwvyfsD4AG4WA==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
|
||||
"@radix-ui/react-visually-hidden@0.1.4":
|
||||
version "0.1.4"
|
||||
resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03"
|
||||
integrity sha512-K/q6AEEzqeeEq/T0NPChvBqnwlp8Tl4NnQdrI/y8IOY7BRR+Ug0PEsVk6g48HJ7cA1//COugdxXXVVK/m0X1mA==
|
||||
dependencies:
|
||||
"@babel/runtime" "^7.13.10"
|
||||
"@radix-ui/react-primitive" "0.1.4"
|
||||
|
||||
"@rollup/plugin-node-resolve@^7.1.1":
|
||||
version "7.1.3"
|
||||
resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz"
|
||||
|
@ -9,7 +9,7 @@ import torch
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
|
||||
|
||||
def download_model(url):
|
||||
def get_cache_path_by_url(url):
|
||||
parts = urlparse(url)
|
||||
hub_dir = get_dir()
|
||||
model_dir = os.path.join(hub_dir, "checkpoints")
|
||||
@ -17,6 +17,11 @@ def download_model(url):
|
||||
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(model_dir, filename)
|
||||
return cached_file
|
||||
|
||||
|
||||
def download_model(url):
|
||||
cached_file = get_cache_path_by_url(url)
|
||||
if not os.path.exists(cached_file):
|
||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = None
|
||||
@ -31,7 +36,11 @@ def ceil_modulo(x, mod):
|
||||
|
||||
|
||||
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
||||
data = cv2.imencode(f".{ext}", image_numpy)[1]
|
||||
data = cv2.imencode(f".{ext}", image_numpy,
|
||||
[
|
||||
int(cv2.IMWRITE_JPEG_QUALITY), 100,
|
||||
int(cv2.IMWRITE_PNG_COMPRESSION), 0
|
||||
])[1]
|
||||
image_bytes = data.tobytes()
|
||||
return image_bytes
|
||||
|
||||
@ -74,13 +83,24 @@ def resize_max_size(
|
||||
return np_img
|
||||
|
||||
|
||||
def pad_img_to_modulo(img, mod):
|
||||
channels, height, width = img.shape
|
||||
def pad_img_to_modulo(img: np.ndarray, mod: int):
|
||||
"""
|
||||
|
||||
Args:
|
||||
img: [H, W, C]
|
||||
mod:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if len(img.shape) == 2:
|
||||
img = img[:, :, np.newaxis]
|
||||
height, width = img.shape[:2]
|
||||
out_height = ceil_modulo(height, mod)
|
||||
out_width = ceil_modulo(width, mod)
|
||||
return np.pad(
|
||||
img,
|
||||
((0, 0), (0, out_height - height), (0, out_width - width)),
|
||||
((0, out_height - height), (0, out_width - width), (0, 0)),
|
||||
mode="symmetric",
|
||||
)
|
||||
|
||||
@ -88,15 +108,13 @@ def pad_img_to_modulo(img, mod):
|
||||
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
||||
"""
|
||||
Args:
|
||||
mask: (1, h, w) 0~1
|
||||
mask: (h, w, 1) 0~255
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
height, width = mask.shape[1:]
|
||||
_, thresh = cv2.threshold(
|
||||
(mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0
|
||||
)
|
||||
height, width = mask.shape[:2]
|
||||
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
boxes = []
|
||||
|
@ -1,121 +0,0 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
|
||||
|
||||
LAMA_MODEL_URL = os.environ.get(
|
||||
"LAMA_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
)
|
||||
|
||||
|
||||
class LaMa:
|
||||
def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
|
||||
"""
|
||||
|
||||
Args:
|
||||
crop_trigger_size: h, w
|
||||
crop_margin:
|
||||
device:
|
||||
"""
|
||||
self.crop_trigger_size = crop_trigger_size
|
||||
self.crop_margin = crop_margin
|
||||
self.device = device
|
||||
|
||||
if os.environ.get("LAMA_MODEL"):
|
||||
model_path = os.environ.get("LAMA_MODEL")
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(
|
||||
f"lama torchscript model not found: {model_path}"
|
||||
)
|
||||
else:
|
||||
model_path = download_model(LAMA_MODEL_URL)
|
||||
|
||||
print(f"Load LaMa model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.model = model
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask):
|
||||
"""
|
||||
image: [C, H, W] RGB
|
||||
mask: [1, H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
area = image.shape[1] * image.shape[2]
|
||||
if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
|
||||
return self._run(image, mask)
|
||||
|
||||
print("Trigger crop image")
|
||||
boxes = boxes_from_mask(mask)
|
||||
crop_result = []
|
||||
for box in boxes:
|
||||
crop_image, crop_box = self._run_box(image, mask, box)
|
||||
crop_result.append((crop_image, crop_box))
|
||||
|
||||
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
|
||||
for crop_image, crop_box in crop_result:
|
||||
x1, y1, x2, y2 = crop_box
|
||||
image[y1:y2, x1:x2, :] = crop_image
|
||||
return image
|
||||
|
||||
def _run_box(self, image, mask, box):
|
||||
"""
|
||||
|
||||
Args:
|
||||
image: [C, H, W] RGB
|
||||
mask: [1, H, W]
|
||||
box: [left,top,right,bottom]
|
||||
|
||||
Returns:
|
||||
BGR IMAGE
|
||||
"""
|
||||
box_h = box[3] - box[1]
|
||||
box_w = box[2] - box[0]
|
||||
cx = (box[0] + box[2]) // 2
|
||||
cy = (box[1] + box[3]) // 2
|
||||
img_h, img_w = image.shape[1:]
|
||||
|
||||
w = box_w + self.crop_margin * 2
|
||||
h = box_h + self.crop_margin * 2
|
||||
|
||||
l = max(cx - w // 2, 0)
|
||||
t = max(cy - h // 2, 0)
|
||||
r = min(cx + w // 2, img_w)
|
||||
b = min(cy + h // 2, img_h)
|
||||
|
||||
crop_img = image[:, t:b, l:r]
|
||||
crop_mask = mask[:, t:b, l:r]
|
||||
|
||||
print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
||||
|
||||
return self._run(crop_img, crop_mask), [l, t, r, b]
|
||||
|
||||
def _run(self, image, mask):
|
||||
"""
|
||||
image: [C, H, W] RGB
|
||||
mask: [1, H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
device = self.device
|
||||
origin_height, origin_width = image.shape[1:]
|
||||
image = pad_img_to_modulo(image, mod=8)
|
||||
mask = pad_img_to_modulo(mask, mod=8)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
inpainted_image = self.model(image, mask)
|
||||
|
||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
cur_res = cur_res[0:origin_height, 0:origin_width, :]
|
||||
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
||||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
||||
return cur_res
|
0
lama_cleaner/model/__init__.py
Normal file
127
lama_cleaner/model/base.py
Normal file
@ -0,0 +1,127 @@
|
||||
import abc
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
|
||||
from lama_cleaner.schema import Config, HDStrategy
|
||||
|
||||
|
||||
class InpaintModel:
|
||||
pad_mod = 8
|
||||
|
||||
def __init__(self, device):
|
||||
"""
|
||||
|
||||
Args:
|
||||
device:
|
||||
"""
|
||||
self.device = device
|
||||
self.init_model(device)
|
||||
|
||||
@abc.abstractmethod
|
||||
def init_model(self, device):
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def is_downloaded() -> bool:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
...
|
||||
|
||||
def _pad_forward(self, image, mask, config: Config):
|
||||
origin_height, origin_width = image.shape[:2]
|
||||
padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
|
||||
padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
|
||||
result = self.forward(padd_image, padd_mask, config)
|
||||
result = result[0:origin_height, 0:origin_width, :]
|
||||
|
||||
original_pixel_indices = mask != 255
|
||||
result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: Config):
|
||||
"""
|
||||
image: [H, W, C] RGB, not normalized
|
||||
mask: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
inpaint_result = None
|
||||
logger.info(f"hd_strategy: {config.hd_strategy}")
|
||||
if config.hd_strategy == HDStrategy.CROP:
|
||||
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
||||
logger.info(f"Run crop strategy")
|
||||
boxes = boxes_from_mask(mask)
|
||||
crop_result = []
|
||||
for box in boxes:
|
||||
crop_image, crop_box = self._run_box(image, mask, box, config)
|
||||
crop_result.append((crop_image, crop_box))
|
||||
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
for crop_image, crop_box in crop_result:
|
||||
x1, y1, x2, y2 = crop_box
|
||||
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
||||
|
||||
elif config.hd_strategy == HDStrategy.RESIZE:
|
||||
if max(image.shape) > config.hd_strategy_resize_limit:
|
||||
origin_size = image.shape[:2]
|
||||
downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
|
||||
downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
|
||||
|
||||
logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
|
||||
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||
|
||||
# only paste masked area result
|
||||
inpaint_result = cv2.resize(inpaint_result,
|
||||
(origin_size[1], origin_size[0]),
|
||||
interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
original_pixel_indices = mask != 255
|
||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||
|
||||
if inpaint_result is None:
|
||||
inpaint_result = self._pad_forward(image, mask, config)
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def _run_box(self, image, mask, box, config: Config):
|
||||
"""
|
||||
|
||||
Args:
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1]
|
||||
box: [left,top,right,bottom]
|
||||
|
||||
Returns:
|
||||
BGR IMAGE
|
||||
"""
|
||||
box_h = box[3] - box[1]
|
||||
box_w = box[2] - box[0]
|
||||
cx = (box[0] + box[2]) // 2
|
||||
cy = (box[1] + box[3]) // 2
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
w = box_w + config.hd_strategy_crop_margin * 2
|
||||
h = box_h + config.hd_strategy_crop_margin * 2
|
||||
|
||||
l = max(cx - w // 2, 0)
|
||||
t = max(cy - h // 2, 0)
|
||||
r = min(cx + w // 2, img_w)
|
||||
b = min(cy + h // 2, img_h)
|
||||
|
||||
crop_img = image[t:b, l:r, :]
|
||||
crop_mask = mask[t:b, l:r]
|
||||
|
||||
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
||||
|
||||
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
68
lama_cleaner/model/lama.py
Normal file
@ -0,0 +1,68 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
LAMA_MODEL_URL = os.environ.get(
|
||||
"LAMA_MODEL_URL",
|
||||
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
)
|
||||
|
||||
|
||||
class LaMa(InpaintModel):
|
||||
pad_mod = 8
|
||||
|
||||
def __init__(self, device):
|
||||
"""
|
||||
|
||||
Args:
|
||||
device:
|
||||
"""
|
||||
super().__init__(device)
|
||||
self.device = device
|
||||
|
||||
def init_model(self, device):
|
||||
if os.environ.get("LAMA_MODEL"):
|
||||
model_path = os.environ.get("LAMA_MODEL")
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(
|
||||
f"lama torchscript model not found: {model_path}"
|
||||
)
|
||||
else:
|
||||
model_path = download_model(LAMA_MODEL_URL)
|
||||
logger.info(f"Load LaMa model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.model_path = model_path
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
image = norm_img(image)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
|
||||
inpainted_image = self.model(image, mask)
|
||||
|
||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
||||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
||||
return cur_res
|
@ -2,13 +2,16 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
torch.manual_seed(42)
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
from lama_cleaner.helper import pad_img_to_modulo, download_model
|
||||
from lama_cleaner.ldm.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
||||
timestep_embedding
|
||||
|
||||
LDM_ENCODE_MODEL_URL = os.environ.get(
|
||||
@ -217,7 +220,7 @@ class DDIMSampler(object):
|
||||
|
||||
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
|
||||
@ -263,102 +266,54 @@ class DDIMSampler(object):
|
||||
|
||||
def load_jit_model(url, device):
|
||||
model_path = download_model(url)
|
||||
logger.info(f"Load LDM model from: {model_path}")
|
||||
model = torch.jit.load(model_path).to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class LDM:
|
||||
def __init__(self, device, steps=50):
|
||||
class LDM(InpaintModel):
|
||||
pad_mod = 32
|
||||
|
||||
def __init__(self, device):
|
||||
super().__init__(device)
|
||||
self.device = device
|
||||
|
||||
def init_model(self, device):
|
||||
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
||||
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
||||
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
||||
|
||||
model = LatentDiffusion(self.diffusion_model, device)
|
||||
self.sampler = DDIMSampler(model)
|
||||
self.steps = steps
|
||||
|
||||
def _norm(self, tensor):
|
||||
return tensor * 2.0 - 1.0
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
model_paths = [
|
||||
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
||||
]
|
||||
return all([os.path.exists(it) for it in model_paths])
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask):
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""
|
||||
image: [C, H, W] RGB
|
||||
mask: [1, H, W]
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
# image [1,3,512,512] float32
|
||||
# mask: [1,1,512,512] float32
|
||||
# masked_image: [1,3,512,512] float32
|
||||
origin_height, origin_width = image.shape[1:]
|
||||
image = pad_img_to_modulo(image, mod=32)
|
||||
mask = pad_img_to_modulo(mask, mod=32)
|
||||
padded_height, padded_width = image.shape[1:]
|
||||
steps = config.ldm_steps
|
||||
image = norm_img(image)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# crop 512 x 512
|
||||
if padded_width <= 512 or padded_height <= 512:
|
||||
np_img = self._forward(image, mask, self.device)
|
||||
else:
|
||||
print("Try to zoom in")
|
||||
# zoom in
|
||||
# x,y,w,h
|
||||
# box = self.box_from_bitmap(mask)
|
||||
box = self.find_main_content(mask)
|
||||
if box is None:
|
||||
print("No bbox found")
|
||||
np_img = self._forward(image, mask, self.device)
|
||||
else:
|
||||
print(f"box: {box}")
|
||||
box_x, box_y, box_w, box_h = box
|
||||
cx = box_x + box_w // 2
|
||||
cy = box_y + box_h // 2
|
||||
|
||||
# w = max(512, box_w)
|
||||
# h = max(512, box_h)
|
||||
w = box_w + 512
|
||||
h = box_h + 512
|
||||
|
||||
left = max(cx - w // 2, 0)
|
||||
top = max(cy - h // 2, 0)
|
||||
right = min(cx + w // 2, origin_width)
|
||||
bottom = min(cy + h // 2, origin_height)
|
||||
|
||||
x = left
|
||||
y = top
|
||||
w = right - left
|
||||
h = bottom - top
|
||||
|
||||
crop_img = image[:, int(y):int(y + h), int(x):int(x + w)]
|
||||
crop_mask = mask[:, int(y):int(y + h), int(x):int(x + w)]
|
||||
|
||||
print(f"Apply zoom in size width x height: {crop_img.shape}")
|
||||
|
||||
crop_img_height, crop_img_width = crop_img.shape[1:]
|
||||
|
||||
crop_img = pad_img_to_modulo(crop_img, mod=32)
|
||||
crop_mask = pad_img_to_modulo(crop_mask, mod=32)
|
||||
# RGB
|
||||
np_img = self._forward(crop_img, crop_mask, self.device)
|
||||
|
||||
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
image[int(y): int(y + h), int(x): int(x + w), :] = np_img[0:crop_img_height, 0:crop_img_width, :]
|
||||
np_img = image
|
||||
# BGR to RGB
|
||||
# np_img = image[:, :, ::-1]
|
||||
|
||||
np_img = np_img[0:origin_height, 0:origin_width, :]
|
||||
np_img = np_img[:, :, ::-1]
|
||||
|
||||
return np_img
|
||||
|
||||
def _forward(self, image, mask, device):
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
masked_image = (1 - mask) * image
|
||||
|
||||
image = self._norm(image)
|
||||
@ -371,47 +326,20 @@ class LDM:
|
||||
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
||||
|
||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
||||
samples_ddim = self.sampler.sample(steps=self.steps,
|
||||
samples_ddim = self.sampler.sample(steps=steps,
|
||||
conditioning=c,
|
||||
batch_size=c.shape[0],
|
||||
shape=shape)
|
||||
x_samples_ddim = self.cond_stage_model_decode(samples_ddim) # samples_ddim: 1, 3, 128, 128 float32
|
||||
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
inpainted = (1 - mask) * image + mask * predicted_image
|
||||
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
||||
np_img = inpainted.astype(np.uint8)
|
||||
return np_img
|
||||
# inpainted = (1 - mask) * image + mask * predicted_image
|
||||
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
||||
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
|
||||
return inpainted_image
|
||||
|
||||
def find_main_content(self, bitmap: np.ndarray):
|
||||
th2 = bitmap[0].astype(np.uint8)
|
||||
row_sum = th2.sum(1)
|
||||
col_sum = th2.sum(0)
|
||||
xmin = max(0, np.argwhere(col_sum != 0).min() - 20)
|
||||
xmax = min(np.argwhere(col_sum != 0).max() + 20, th2.shape[1])
|
||||
ymin = max(0, np.argwhere(row_sum != 0).min() - 20)
|
||||
ymax = min(np.argwhere(row_sum != 0).max() + 20, th2.shape[0])
|
||||
|
||||
left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)
|
||||
return left, top, right - left, bottom - top
|
||||
|
||||
def box_from_bitmap(self, bitmap):
|
||||
"""
|
||||
bitmap: single map with shape (NUM_CLASSES, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
"""
|
||||
contours, _ = cv2.findContours(
|
||||
(bitmap[0] * 255).astype(np.uint8), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_NONE
|
||||
)
|
||||
|
||||
contours = sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)
|
||||
num_contours = len(contours)
|
||||
print(f"contours size: {num_contours}")
|
||||
if num_contours != 1:
|
||||
return None
|
||||
|
||||
# x,y,w,h
|
||||
return cv2.boundingRect(contours[0])
|
||||
def _norm(self, tensor):
|
||||
return tensor * 2.0 - 1.0
|
42
lama_cleaner/model_manager.py
Normal file
@ -0,0 +1,42 @@
|
||||
from lama_cleaner.model.lama import LaMa
|
||||
from lama_cleaner.model.ldm import LDM
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
class ModelManager:
|
||||
LAMA = 'lama'
|
||||
LDM = 'ldm'
|
||||
|
||||
def __init__(self, name: str, device):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.model = self.init_model(name, device)
|
||||
|
||||
def init_model(self, name: str, device):
|
||||
if name == self.LAMA:
|
||||
model = LaMa(device)
|
||||
elif name == self.LDM:
|
||||
model = LDM(device)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported model: {name}")
|
||||
return model
|
||||
|
||||
def is_downloaded(self, name: str) -> bool:
|
||||
if name == self.LAMA:
|
||||
return LaMa.is_downloaded()
|
||||
elif name == self.LDM:
|
||||
return LDM.is_downloaded()
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported model: {name}")
|
||||
|
||||
def __call__(self, image, mask, config: Config):
|
||||
return self.model(image, mask, config)
|
||||
|
||||
def switch(self, new_name: str):
|
||||
if new_name == self.name:
|
||||
return
|
||||
try:
|
||||
self.model = self.init_model(new_name, self.device)
|
||||
self.name = new_name
|
||||
except NotImplementedError as e:
|
||||
raise e
|
32
lama_cleaner/parse_args.py
Normal file
@ -0,0 +1,32 @@
|
||||
import os
|
||||
import imghdr
|
||||
import argparse
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", default="127.0.0.1")
|
||||
parser.add_argument("--port", default=8080, type=int)
|
||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
|
||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||
parser.add_argument(
|
||||
"--gui-size",
|
||||
default=[1600, 1000],
|
||||
nargs=2,
|
||||
type=int,
|
||||
help="Set window size for GUI",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input", type=str, help="Path to image you want to load by default"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.input is not None:
|
||||
if not os.path.exists(args.input):
|
||||
parser.error(f"invalid --input: {args.input} not exists")
|
||||
if imghdr.what(args.input) is None:
|
||||
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
||||
|
||||
return args
|
17
lama_cleaner/schema.py
Normal file
@ -0,0 +1,17 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HDStrategy(str, Enum):
|
||||
ORIGINAL = 'Original'
|
||||
RESIZE = 'Resize'
|
||||
CROP = 'Crop'
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
ldm_steps: int
|
||||
hd_strategy: str
|
||||
hd_strategy_crop_margin: int
|
||||
hd_strategy_crop_trigger_size: int
|
||||
hd_strategy_resize_limit: int
|
189
lama_cleaner/server.py
Normal file
@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import imghdr
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
from flask import Flask, request, send_file, cli
|
||||
# Disable ability for Flask to display warning about using a development server in a production environment.
|
||||
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
||||
cli.show_server_banner = lambda *_: None
|
||||
from flask_cors import CORS
|
||||
|
||||
from lama_cleaner.helper import (
|
||||
load_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
)
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||
if os.environ.get("CACHE_DIR"):
|
||||
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
||||
|
||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
||||
|
||||
|
||||
class NoFlaskwebgui(logging.Filter):
|
||||
def filter(self, record):
|
||||
return 'GET //flaskwebgui-keep-server-alive' not in record.getMessage()
|
||||
|
||||
|
||||
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
CORS(app, expose_headers=["Content-Disposition"])
|
||||
|
||||
model: ModelManager = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
w = imghdr.what("", img_bytes)
|
||||
if w is None:
|
||||
w = "jpeg"
|
||||
return w
|
||||
|
||||
|
||||
@app.route("/inpaint", methods=["POST"])
|
||||
def process():
|
||||
input = request.files
|
||||
# RGB
|
||||
origin_image_bytes = input["image"].read()
|
||||
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
|
||||
form = request.form
|
||||
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
|
||||
if size_limit == "Original":
|
||||
size_limit = max(image.shape)
|
||||
else:
|
||||
size_limit = int(size_limit)
|
||||
|
||||
config = Config(
|
||||
ldm_steps=form['ldmSteps'],
|
||||
hd_strategy=form['hdStrategy'],
|
||||
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
|
||||
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
|
||||
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
|
||||
)
|
||||
|
||||
logger.info(f"Origin image shape: {original_shape}")
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
logger.info(f"Resized image shape: {image.shape}")
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
|
||||
start = time.time()
|
||||
res_np_img = model(image, mask, config)
|
||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
||||
alpha_channel = cv2.resize(
|
||||
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
|
||||
)
|
||||
res_np_img = np.concatenate(
|
||||
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||
)
|
||||
|
||||
ext = get_image_ext(origin_image_bytes)
|
||||
return send_file(
|
||||
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||
mimetype=f"image/{ext}",
|
||||
)
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
|
||||
|
||||
@app.route("/model_downloaded/<name>")
|
||||
def model_downloaded(name):
|
||||
return str(model.is_downloaded(name)), 200
|
||||
|
||||
|
||||
@app.route("/model", methods=["POST"])
|
||||
def switch_model():
|
||||
new_name = request.form.get("name")
|
||||
if new_name == model.name:
|
||||
return "Same model", 200
|
||||
|
||||
try:
|
||||
model.switch(new_name)
|
||||
except NotImplementedError:
|
||||
return f"{new_name} not implemented", 403
|
||||
return f"ok, switch to {new_name}", 200
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
|
||||
|
||||
@app.route("/inputimage")
|
||||
def set_input_photo():
|
||||
if input_image_path:
|
||||
with open(input_image_path, "rb") as f:
|
||||
image_in_bytes = f.read()
|
||||
return send_file(
|
||||
input_image_path,
|
||||
as_attachment=True,
|
||||
download_name=Path(input_image_path).name,
|
||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||
)
|
||||
else:
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
def main(args):
|
||||
global model
|
||||
global device
|
||||
global input_image_path
|
||||
|
||||
device = torch.device(args.device)
|
||||
input_image_path = args.input
|
||||
|
||||
model = ModelManager(name=args.model, device=device)
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
from flaskwebgui import FlaskUI
|
||||
ui = FlaskUI(app, width=app_width, height=app_height, host=args.host, port=args.port)
|
||||
ui.run()
|
||||
else:
|
||||
app.run(host=args.host, port=args.port, debug=args.debug)
|
0
lama_cleaner/tests/__init__.py
Normal file
BIN
lama_cleaner/tests/image.png
Normal file
After Width: | Height: | Size: 129 KiB |
BIN
lama_cleaner/tests/lama_crop_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/lama_original_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/lama_resize_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/ldm_crop_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/ldm_original_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
BIN
lama_cleaner/tests/ldm_resize_result.png
Normal file
After Width: | Height: | Size: 193 KiB |
Before Width: | Height: | Size: 11 KiB |
BIN
lama_cleaner/tests/mask.png
Normal file
After Width: | Height: | Size: 7.7 KiB |
@ -1,15 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lama_cleaner.helper import boxes_from_mask
|
||||
|
||||
|
||||
def test_boxes_from_mask():
|
||||
mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
|
||||
mask = mask[:, :, np.newaxis]
|
||||
mask = (mask / 255).transpose(2, 0, 1)
|
||||
boxes = boxes_from_mask(mask)
|
||||
print(boxes)
|
||||
|
||||
|
||||
test_boxes_from_mask()
|
55
lama_cleaner/tests/test_model.py
Normal file
@ -0,0 +1,55 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config, HDStrategy
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
|
||||
|
||||
def get_data():
|
||||
img = cv2.imread(str(current_dir / 'image.png'))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||
mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
|
||||
return img, mask
|
||||
|
||||
|
||||
def get_config(strategy):
|
||||
return Config(
|
||||
ldm_steps=1,
|
||||
hd_strategy=strategy,
|
||||
hd_strategy_crop_margin=32,
|
||||
hd_strategy_crop_trigger_size=200,
|
||||
hd_strategy_resize_limit=200,
|
||||
)
|
||||
|
||||
|
||||
def assert_equal(model, config, gt_name):
|
||||
img, mask = get_data()
|
||||
res = model(img, mask, config)
|
||||
# cv2.imwrite(gt_name, res,
|
||||
# [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])
|
||||
|
||||
"""
|
||||
Note that JPEG is lossy compression, so even if it is the highest quality 100,
|
||||
when the saved image is reloaded, a difference occurs with the original pixel value.
|
||||
If you want to save the original image as it is, save it as PNG or BMP.
|
||||
"""
|
||||
gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
|
||||
assert np.array_equal(res, gt)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
|
||||
def test_lama(strategy):
|
||||
model = ModelManager(name='lama', device='cpu')
|
||||
assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
|
||||
def test_ldm(strategy):
|
||||
model = ModelManager(name='ldm', device='cpu')
|
||||
assert_equal(model, get_config(strategy), f'ldm_{strategy[0].upper() + strategy[1:]}_result.png')
|
210
main.py
@ -1,210 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import imghdr
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from lama_cleaner.lama import LaMa
|
||||
from lama_cleaner.ldm import LDM
|
||||
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
from flask import Flask, request, send_file
|
||||
from flask_cors import CORS
|
||||
|
||||
from lama_cleaner.helper import (
|
||||
load_img,
|
||||
norm_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
)
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||
if os.environ.get("CACHE_DIR"):
|
||||
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
||||
|
||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
CORS(app)
|
||||
|
||||
model = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
w = imghdr.what("", img_bytes)
|
||||
if w is None:
|
||||
w = "jpeg"
|
||||
return w
|
||||
|
||||
|
||||
@app.route("/inpaint", methods=["POST"])
|
||||
def process():
|
||||
input = request.files
|
||||
# RGB
|
||||
origin_image_bytes = input["image"].read()
|
||||
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
|
||||
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
|
||||
if size_limit == "Original":
|
||||
size_limit = max(image.shape)
|
||||
else:
|
||||
size_limit = int(size_limit)
|
||||
|
||||
print(f"Origin image shape: {original_shape}")
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
print(f"Resized image shape: {image.shape}")
|
||||
image = norm_img(image)
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
mask = norm_img(mask)
|
||||
|
||||
start = time.time()
|
||||
res_np_img = model(image, mask)
|
||||
print(f"process time: {(time.time() - start) * 1000}ms")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
||||
alpha_channel = cv2.resize(
|
||||
alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0])
|
||||
)
|
||||
res_np_img = np.concatenate(
|
||||
(res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||
)
|
||||
|
||||
ext = get_image_ext(origin_image_bytes)
|
||||
return send_file(
|
||||
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||
mimetype=f"image/{ext}",
|
||||
)
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
|
||||
|
||||
@app.route("/inputimage")
|
||||
def set_input_photo():
|
||||
if input_image_path:
|
||||
with open(input_image_path, "rb") as f:
|
||||
image_in_bytes = f.read()
|
||||
return send_file(
|
||||
io.BytesIO(image_in_bytes),
|
||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||
)
|
||||
else:
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input", type=str, help="Path to image you want to load by default"
|
||||
)
|
||||
parser.add_argument("--port", default=8080, type=int)
|
||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||
parser.add_argument(
|
||||
"--crop-trigger-size",
|
||||
default=[2042, 2042],
|
||||
nargs=2,
|
||||
type=int,
|
||||
help="If image size large then crop-trigger-size, "
|
||||
"crop each area from original image to do inference."
|
||||
"Mainly for performance and memory reasons"
|
||||
"Only for lama",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-margin",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Margin around bounding box of painted stroke when crop mode triggered",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ldm-steps",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Steps for DDIM sampling process."
|
||||
"The larger the value, the better the result, but it will be more time-consuming",
|
||||
)
|
||||
parser.add_argument("--device", default="cuda", type=str)
|
||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||
parser.add_argument(
|
||||
"--gui-size",
|
||||
default=[1600, 1000],
|
||||
nargs=2,
|
||||
type=int,
|
||||
help="Set window size for GUI",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.input is not None:
|
||||
if not os.path.exists(args.input):
|
||||
parser.error(f"invalid --input: {args.input} not exists")
|
||||
if imghdr.what(args.input) is None:
|
||||
parser.error(f"invalid --input: {args.input} is not a valid image file")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
global model
|
||||
global device
|
||||
global input_image_path
|
||||
|
||||
args = get_args_parser()
|
||||
device = torch.device(args.device)
|
||||
input_image_path = args.input
|
||||
|
||||
if args.model == "lama":
|
||||
model = LaMa(
|
||||
crop_trigger_size=args.crop_trigger_size,
|
||||
crop_margin=args.crop_margin,
|
||||
device=device,
|
||||
)
|
||||
elif args.model == "ldm":
|
||||
model = LDM(device, steps=args.ldm_steps)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported model: {args.model}")
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
ui = FlaskUI(app, width=app_width, height=app_height)
|
||||
ui.run()
|
||||
else:
|
||||
app.run(host="127.0.0.1", port=args.port, debug=args.debug)
|
||||
|
||||
from lama_cleaner import entry_point
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
entry_point()
|
||||
|
10
publish.sh
Normal file
@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
pushd ./lama_cleaner/app
|
||||
yarn run build
|
||||
popd
|
||||
|
||||
rm -r -f dist
|
||||
python3 setup.py sdist bdist_wheel
|
||||
twine upload dist/*
|
2
requirements-dev.txt
Normal file
@ -0,0 +1,2 @@
|
||||
wheel
|
||||
twine
|
@ -1,6 +1,9 @@
|
||||
torch>=1.8.2
|
||||
opencv-python
|
||||
flask_cors
|
||||
flask
|
||||
flask>=2.1.1
|
||||
flaskwebgui
|
||||
tqdm
|
||||
pydantic
|
||||
loguru
|
||||
pytest
|
||||
|
42
setup.py
@ -1 +1,41 @@
|
||||
# TODO: make this a python package
|
||||
import setuptools
|
||||
from pathlib import Path
|
||||
|
||||
web_files = Path("lama_cleaner/app/build/").glob("**/*")
|
||||
web_files = [str(it).replace("lama_cleaner/", "") for it in web_files]
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
|
||||
def load_requirements():
|
||||
requirements_file_name = "requirements.txt"
|
||||
requires = []
|
||||
with open(requirements_file_name) as f:
|
||||
for line in f:
|
||||
if line:
|
||||
requires.append(line.strip())
|
||||
return requires
|
||||
|
||||
|
||||
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
||||
setuptools.setup(
|
||||
name="lama-cleaner",
|
||||
version="0.9.1",
|
||||
author="PanicByte",
|
||||
author_email="cwq1913@gmail.com",
|
||||
description="Image inpainting tool powered by SOTA AI Model",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/Sanster/lama-cleaner",
|
||||
packages=setuptools.find_packages("./"),
|
||||
package_data={"lama_cleaner": web_files},
|
||||
install_requires=load_requirements(),
|
||||
python_requires=">=3.6",
|
||||
entry_points={"console_scripts": ["lama-cleaner=lama_cleaner:entry_point"]},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
)
|
||||
|