This commit is contained in:
Qing 2022-09-15 22:21:27 +08:00
parent 3ac6ee7f44
commit 32854d40da
52 changed files with 2258 additions and 205 deletions

View File

@ -28,6 +28,7 @@
1. [ZITS](https://github.com/DQiaole/ZITS_inpainting)
1. [MAT](https://github.com/fenglinglwb/MAT)
1. [FcF](https://github.com/SHI-Labs/FcF-Inpainting)
1. [SD](https://github.com/CompVis/stable-diffusion)
- Support CPU & GPU
- Various inpainting [strategy](#inpainting-strategy)
- Run as a desktop APP
@ -54,15 +55,16 @@ lama-cleaner --model=lama --device=cpu --port=8080
Available arguments:
| Name | Description | Default |
| ---------- | ---------------------------------------------------------------- | -------- |
| --model | lama/ldm/zits. See details in [Inpaint Model](#inpainting-model) | lama |
| --device | cuda or cpu | cuda |
| --port | Port for backend flask web server | 8080 |
| --gui | Launch lama-cleaner as a desktop application | |
| --gui_size | Set the window size for the application | 1200 900 |
| --input | Path to image you want to load by default | None |
| --debug | Enable debug mode for flask web server | |
| Name | Description | Default |
| ----------------- | -------------------------------------------------------------------------------------------------------- | -------- |
| --model | lama/ldm/zits/mat/fcf/sd. See details in [Inpaint Model](#inpainting-model) | lama |
| --hf_access_token | stable-diffusion(sd) model need huggingface access token https://huggingface.co/docs/hub/security-tokens | |
| --device | cuda or cpu | cuda |
| --port | Port for backend flask web server | 8080 |
| --gui | Launch lama-cleaner as a desktop application | |
| --gui_size | Set the window size for the application | 1200 900 |
| --input | Path to image you want to load by default | None |
| --debug | Enable debug mode for flask web server | |
## Inpainting Model
@ -73,6 +75,7 @@ Available arguments:
| ZITS | :+1: Better holistic structures compared with previous methods <br/> :neutral_face: Wireframe module is **very** slow on CPU | `Wireframe`: Enable edge and line detect |
| MAT | TODO | |
| FcF | :+1: Better structure and texture generation <br/> :neutral_face: Only support fixed size (512x512) input | |
| SD | :+1: SOTA text-to-image diffusion model | |
### LaMa vs LDM

View File

@ -6,6 +6,7 @@
"dependencies": {
"@heroicons/react": "^1.0.4",
"@radix-ui/react-dialog": "0.1.8-rc.25",
"@radix-ui/react-popover": "^1.0.0",
"@radix-ui/react-select": "0.1.2-rc.27",
"@radix-ui/react-switch": "^0.1.5",
"@radix-ui/react-toast": "^0.1.1",
@ -20,14 +21,17 @@
"@types/react-dom": "^17.0.9",
"cross-env": "7.x",
"lodash": "^4.17.21",
"mitt": "^3.0.0",
"nanoid": "^4.0.0",
"npm-run-all": "4.x",
"react": "^17.0.2",
"react-dom": "^17.0.2",
"react-hotkeys-hook": "^3.4.7",
"react-scripts": "4.0.3",
"react-use": "^17.3.1",
"react-zoom-pan-pinch": "^2.1.3",
"recoil": "^0.6.1",
"socket.io-client": "^4.5.2",
"typescript": "4.x"
},
"scripts": {

View File

@ -1,5 +1,4 @@
import React, { useEffect, useMemo } from 'react'
import { useKeyPressEvent } from 'react-use'
import { useRecoilState } from 'recoil'
import { nanoid } from 'nanoid'
import useInputImage from './hooks/useInputImage'
@ -9,6 +8,7 @@ import Workspace from './components/Workspace'
import { fileState } from './store/Atoms'
import { keepGUIAlive } from './utils'
import Header from './components/Header/Header'
import useHotKey from './hooks/useHotkey'
// Keeping GUI Window Open
keepGUIAlive()
@ -24,11 +24,15 @@ function App() {
}, [userInputImage, setFile])
// Dark Mode Hotkey
useKeyPressEvent('D', ev => {
ev?.preventDefault()
const newTheme = theme === 'light' ? 'dark' : 'light'
setTheme(newTheme)
})
useHotKey(
'shift+d',
() => {
const newTheme = theme === 'light' ? 'dark' : 'light'
setTheme(newTheme)
},
{},
[theme]
)
useEffect(() => {
document.body.setAttribute('data-theme', theme)

View File

@ -1,4 +1,4 @@
import { Settings } from '../store/Atoms'
import { Rect, Settings } from '../store/Atoms'
import { dataURItoBlob } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
@ -7,6 +7,8 @@ export default async function inpaint(
imageFile: File,
maskBase64: string,
settings: Settings,
croperRect: Rect,
prompt?: string,
sizeLimit?: string
) {
// 1080, 2000, Original
@ -30,6 +32,18 @@ export default async function inpaint(
hdSettings.hdStrategyResizeLimit.toString()
)
fd.append('prompt', prompt === undefined ? '' : prompt)
fd.append('croperX', croperRect.x.toString())
fd.append('croperY', croperRect.y.toString())
fd.append('croperHeight', croperRect.height.toString())
fd.append('croperWidth', croperRect.width.toString())
fd.append('useCroper', settings.showCroper ? 'true' : 'false')
fd.append('sdStrength', settings.sdStrength.toString())
fd.append('sdSteps', settings.sdSteps.toString())
fd.append('sdGuidanceScale', settings.sdGuidanceScale.toString())
fd.append('sdSampler', settings.sdSampler.toString())
fd.append('sdSeed', settings.sdSeedFixed ? settings.sdSeed.toString() : '-1')
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {

View File

@ -0,0 +1,126 @@
@use 'sass:math';
$drag-handle-shortside: 12px;
$drag-handle-longside: 40px;
$half-handle-shortside: math.div($drag-handle-shortside, 2);
$half-handle-longside: math.div($drag-handle-longside, 2);
.crop-border {
outline-color: var(--yellow-accent);
outline-style: dashed;
}
.info-bar {
position: absolute;
pointer-events: auto;
font-size: 1rem;
padding: 0.2rem 0.8rem;
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
color: var(--text-color);
background-color: var(--page-bg);
border-radius: 9999px;
border: var(--editor-toolkit-panel-border);
box-shadow: 0 0 0 1px #0000001a, 0 3px 16px #00000014, 0 2px 6px 1px #00000017;
&:hover {
cursor: move;
}
}
.croper-wrapper {
position: absolute;
height: 100%;
width: 100%;
z-index: 2;
overflow: hidden;
pointer-events: none;
}
.croper {
position: relative;
top: 0;
bottom: 0;
left: 0;
right: 0;
z-index: 2;
pointer-events: none;
display: flex;
flex-direction: column;
align-items: center;
box-shadow: 0 0 0 9999px rgba(0, 0, 0, 0.5);
}
.drag-handle {
width: $drag-handle-shortside;
height: $drag-handle-shortside;
z-index: 4;
position: absolute;
display: block;
content: '';
border: 2px solid var(--yellow-accent);
background-color: var(--yellow-accent-light);
pointer-events: auto;
&:hover {
background-color: var(--yellow-accent);
}
}
.ord-topleft {
cursor: nw-resize;
top: (-$half-handle-shortside)-1px;
left: (-$half-handle-shortside)-1px;
}
.ord-topright {
cursor: ne-resize;
top: -($half-handle-shortside)-1px;
right: -($half-handle-shortside)-1px;
}
.ord-bottomright {
cursor: se-resize;
bottom: -($half-handle-shortside)-1px;
right: -($half-handle-shortside)-1px;
}
.ord-bottomleft {
cursor: sw-resize;
bottom: -($half-handle-shortside)-1px;
left: -($half-handle-shortside)-1px;
}
.ord-top,
.ord-bottom {
left: calc(50% - $half-handle-shortside);
cursor: ns-resize;
}
.ord-top {
top: (-$half-handle-shortside)-1px;
}
.ord-bottom {
bottom: -($half-handle-shortside)-1px;
}
.ord-left,
.ord-right {
top: calc(50% - $half-handle-shortside);
cursor: ew-resize;
}
.ord-left {
left: (-$half-handle-shortside)-1px;
}
.ord-right {
right: -($half-handle-shortside)-1px;
}

View File

@ -0,0 +1,326 @@
import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/outline'
import React, { useEffect, useState } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil'
import {
croperHeight,
croperWidth,
croperX,
croperY,
isInpaintingState,
} from '../../store/Atoms'
const DOC_MOVE_OPTS = { capture: true, passive: false }
const DRAG_HANDLE_BORDER = 2
const DRAG_HANDLE_SHORT = 12
const DRAG_HANDLE_LONG = 40
interface EVData {
initX: number
initY: number
initHeight: number
initWidth: number
startResizeX: number
startResizeY: number
ord: string // top/right/bottom/left
}
interface Props {
maxHeight: number
maxWidth: number
scale: number
minHeight: number
minWidth: number
}
const Croper = (props: Props) => {
const { minHeight, minWidth, maxHeight, maxWidth, scale } = props
const [x, setX] = useRecoilState(croperX)
const [y, setY] = useRecoilState(croperY)
const [height, setHeight] = useRecoilState(croperHeight)
const [width, setWidth] = useRecoilState(croperWidth)
const isInpainting = useRecoilValue(isInpaintingState)
const [isResizing, setIsResizing] = useState(false)
const [isMoving, setIsMoving] = useState(false)
useEffect(() => {
setX(Math.round((maxWidth - 512) / 2))
setY(Math.round((maxHeight - 512) / 2))
}, [maxHeight, maxWidth, minHeight, minWidth])
const [evData, setEVData] = useState<EVData>({
initX: 0,
initY: 0,
initHeight: 0,
initWidth: 0,
startResizeX: 0,
startResizeY: 0,
ord: 'top',
})
const onDragFocus = () => {
console.log('focus')
}
const checkTopBottomLimit = (newY: number, newHeight: number) => {
if (newY > 0 && newHeight > minHeight && newY + newHeight <= maxHeight) {
return true
}
return false
}
const checkLeftRightLimit = (newX: number, newWidth: number) => {
if (newX > 0 && newWidth > minWidth && newX + newWidth <= maxWidth) {
return true
}
return false
}
const onPointerMove = (e: PointerEvent) => {
if (isInpainting) {
return
}
const curX = e.clientX
const curY = e.clientY
if (isResizing) {
switch (evData.ord) {
case 'top': {
// TODO: 添加四个角以及 drag bar handle
const offset = Math.round((curY - evData.startResizeY) / scale)
const newHeight = evData.initHeight - offset
const newY = evData.initY + offset
if (checkTopBottomLimit(newY, newHeight)) {
setHeight(newHeight)
setY(newY)
}
break
}
case 'right': {
const offset = Math.round((curX - evData.startResizeX) / scale)
const newWidth = evData.initWidth + offset
if (checkLeftRightLimit(evData.initX, newWidth)) {
setWidth(newWidth)
}
break
}
case 'bottom': {
const offset = Math.round((curY - evData.startResizeY) / scale)
const newHeight = evData.initHeight + offset
if (checkTopBottomLimit(evData.initY, newHeight)) {
setHeight(newHeight)
}
break
}
case 'left': {
const offset = Math.round((curX - evData.startResizeX) / scale)
const newWidth = evData.initWidth - offset
const newX = evData.initX + offset
if (checkLeftRightLimit(newX, newWidth)) {
setWidth(newWidth)
setX(newX)
}
break
}
default:
break
}
}
if (isMoving) {
const offsetX = Math.round((curX - evData.startResizeX) / scale)
const offsetY = Math.round((curY - evData.startResizeY) / scale)
const newX = evData.initX + offsetX
const newY = evData.initY + offsetY
if (
checkLeftRightLimit(newX, evData.initWidth) &&
checkTopBottomLimit(newY, evData.initHeight)
) {
setX(newX)
setY(newY)
}
}
}
const onPointerDone = (e: PointerEvent) => {
if (isResizing) {
setIsResizing(false)
}
if (isMoving) {
setIsMoving(false)
}
}
useEffect(() => {
if (isResizing || isMoving) {
document.addEventListener('pointermove', onPointerMove, DOC_MOVE_OPTS)
document.addEventListener('pointerup', onPointerDone, DOC_MOVE_OPTS)
document.addEventListener('pointercancel', onPointerDone, DOC_MOVE_OPTS)
return () => {
document.removeEventListener(
'pointermove',
onPointerMove,
DOC_MOVE_OPTS
)
document.removeEventListener('pointerup', onPointerDone, DOC_MOVE_OPTS)
document.removeEventListener(
'pointercancel',
onPointerDone,
DOC_MOVE_OPTS
)
}
}
}, [isResizing, isMoving, width, height, evData])
const onCropPointerDown = (e: React.PointerEvent<HTMLDivElement>) => {
const { ord } = (e.target as HTMLElement).dataset
if (ord) {
setIsResizing(true)
setEVData({
initX: x,
initY: y,
initHeight: height,
initWidth: width,
startResizeX: e.clientX,
startResizeY: e.clientY,
ord,
})
}
}
const createCropSelection = () => {
return (
<div
className="drag-elements"
onFocus={onDragFocus}
onPointerDown={onCropPointerDown}
>
<div
className="drag-handle ord-topleft"
data-ord="topleft"
aria-label="topleft"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-topright"
data-ord="topright"
aria-label="topright"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-bottomleft"
data-ord="bottomleft"
aria-label="bottomleft"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-bottomright"
data-ord="bottomright"
aria-label="bottomright"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-top"
data-ord="top"
aria-label="top"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-right"
data-ord="right"
aria-label="right"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-bottom"
data-ord="bottom"
aria-label="bottom"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
<div
className="drag-handle ord-left"
data-ord="left"
aria-label="left"
tabIndex={0}
role="button"
style={{ transform: `scale(${1 / scale})` }}
/>
</div>
)
}
const onInfoBarPointerDown = (e: React.PointerEvent<HTMLDivElement>) => {
setIsMoving(true)
setEVData({
initX: x,
initY: y,
initHeight: height,
initWidth: width,
startResizeX: e.clientX,
startResizeY: e.clientY,
ord: '',
})
}
const createInfoBar = () => {
return (
<div
className="info-bar"
onPointerDown={onInfoBarPointerDown}
style={{
transform: `scale(${1 / scale})`,
top: `${-28 / scale - 14}px`,
}}
>
<div className="crop-size">
{width} x {height}
</div>
</div>
)
}
const createBorder = () => {
return (
<div
className="crop-border"
style={{
height,
width,
outlineWidth: `${DRAG_HANDLE_BORDER / scale}px`,
}}
/>
)
}
return (
<div className="croper-wrapper">
<div className="croper" style={{ height, width, left: x, top: y }}>
{createBorder()}
{createInfoBar()}
{createCropSelection()}
</div>
</div>
)
}
export default Croper

View File

@ -55,7 +55,7 @@
position: fixed;
bottom: 0.5rem;
border-radius: 3rem;
padding: 1rem 3rem;
padding: 0.6rem 3rem;
display: grid;
grid-template-areas: 'toolkit-size-selector toolkit-brush-slider toolkit-btns';
column-gap: 2rem;

View File

@ -22,7 +22,6 @@ import Button from '../shared/Button'
import Slider from './Slider'
import SizeSelector from './SizeSelector'
import {
dataURItoBlob,
downloadImage,
isMidClick,
isRightClick,
@ -30,7 +29,18 @@ import {
srcToFile,
useImage,
} from '../../utils'
import { settingState, toastState } from '../../store/Atoms'
import {
croperState,
isInpaintingState,
isSDState,
propmtState,
runManuallyState,
settingState,
toastState,
} from '../../store/Atoms'
import useHotKey from '../../hooks/useHotkey'
import Croper from '../Croper/Croper'
import emitter, { EVENT_PROMPT } from '../../event'
const TOOLBAR_SIZE = 200
const BRUSH_COLOR = '#ffcc00bb'
@ -74,8 +84,14 @@ function mouseXY(ev: SyntheticEvent) {
export default function Editor(props: EditorProps) {
const { file } = props
const promptVal = useRecoilValue(propmtState)
const settings = useRecoilValue(settingState)
const croperRect = useRecoilValue(croperState)
const [toastVal, setToastState] = useRecoilState(toastState)
const [isInpainting, setIsInpainting] = useRecoilState(isInpaintingState)
const runMannually = useRecoilValue(runManuallyState)
const isSD = useRecoilValue(isSDState)
const [brushSize, setBrushSize] = useState(40)
const [original, isOriginalLoaded] = useImage(file)
const [renders, setRenders] = useState<HTMLImageElement[]>([])
@ -90,7 +106,6 @@ export default function Editor(props: EditorProps) {
const [showRefBrush, setShowRefBrush] = useState(false)
const [isPanning, setIsPanning] = useState<boolean>(false)
const [showOriginal, setShowOriginal] = useState(false)
const [isInpaintingLoading, setIsInpaintingLoading] = useState(false)
const [scale, setScale] = useState<number>(1)
const [panned, setPanned] = useState<boolean>(false)
const [minScale, setMinScale] = useState<number>(1.0)
@ -130,83 +145,28 @@ export default function Editor(props: EditorProps) {
[context, original]
)
const drawLinesOnMask = (_lineGroups: LineGroup[]) => {
if (!context?.canvas.width || !context?.canvas.height) {
throw new Error('canvas has invalid size')
}
maskCanvas.width = context?.canvas.width
maskCanvas.height = context?.canvas.height
const ctx = maskCanvas.getContext('2d')
if (!ctx) {
throw new Error('could not retrieve mask canvas')
}
_lineGroups.forEach(lineGroup => {
drawLines(ctx, lineGroup, 'white')
})
}
const runInpainting = async () => {
if (!hadDrawSomething()) {
return
}
const newLineGroups = [...lineGroups, curLineGroup]
setCurLineGroup([])
setIsDraging(false)
setIsInpaintingLoading(true)
if (settings.graduallyInpainting) {
drawLinesOnMask([curLineGroup])
} else {
drawLinesOnMask(newLineGroups)
}
let targetFile = file
if (settings.graduallyInpainting === true && renders.length > 0) {
console.info('gradually inpainting on last result')
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(lastRender.currentSrc, file.name, file.type)
}
try {
const res = await inpaint(
targetFile,
maskCanvas.toDataURL(),
settings,
sizeLimit.toString()
)
if (!res) {
throw new Error('empty response')
const drawLinesOnMask = useCallback(
(_lineGroups: LineGroup[]) => {
if (!context?.canvas.width || !context?.canvas.height) {
throw new Error('canvas has invalid size')
}
maskCanvas.width = context?.canvas.width
maskCanvas.height = context?.canvas.height
const ctx = maskCanvas.getContext('2d')
if (!ctx) {
throw new Error('could not retrieve mask canvas')
}
const newRender = new Image()
await loadImage(newRender, res)
const newRenders = [...renders, newRender]
setRenders(newRenders)
draw(newRender, [])
// Only append new LineGroup after inpainting success
setLineGroups(newLineGroups)
// clear redo stack
resetRedoState()
} catch (e: any) {
setToastState({
open: true,
desc: e.message ? e.message : e.toString(),
state: 'error',
duration: 2000,
_lineGroups.forEach(lineGroup => {
drawLines(ctx, lineGroup, 'white')
})
drawOnCurrentRender([])
}
setIsInpaintingLoading(false)
}
},
[context, maskCanvas]
)
const hadDrawSomething = () => {
const hadDrawSomething = useCallback(() => {
return curLineGroup.length !== 0
}
const hadRunInpainting = () => {
return renders.length !== 0
}
}, [curLineGroup])
const drawOnCurrentRender = useCallback(
(lineGroup: LineGroup) => {
@ -219,8 +179,107 @@ export default function Editor(props: EditorProps) {
[original, renders, draw]
)
const runInpainting = useCallback(
async (prompt?: string) => {
console.log('runInpainting')
if (!hadDrawSomething()) {
return
}
console.log(prompt)
const newLineGroups = [...lineGroups, curLineGroup]
setCurLineGroup([])
setIsDraging(false)
setIsInpainting(true)
if (settings.graduallyInpainting) {
drawLinesOnMask([curLineGroup])
} else {
drawLinesOnMask(newLineGroups)
}
let targetFile = file
if (settings.graduallyInpainting === true && renders.length > 0) {
console.info('gradually inpainting on last result')
const lastRender = renders[renders.length - 1]
targetFile = await srcToFile(
lastRender.currentSrc,
file.name,
file.type
)
}
try {
const res = await inpaint(
targetFile,
maskCanvas.toDataURL(),
settings,
croperRect,
prompt,
sizeLimit.toString()
)
if (!res) {
throw new Error('empty response')
}
const newRender = new Image()
await loadImage(newRender, res)
const newRenders = [...renders, newRender]
setRenders(newRenders)
draw(newRender, [])
// Only append new LineGroup after inpainting success
setLineGroups(newLineGroups)
// clear redo stack
resetRedoState()
} catch (e: any) {
setToastState({
open: true,
desc: e.message ? e.message : e.toString(),
state: 'error',
duration: 4000,
})
drawOnCurrentRender([])
}
setIsInpainting(false)
},
[
lineGroups,
curLineGroup,
maskCanvas,
settings.graduallyInpainting,
settings,
croperRect,
sizeLimit,
promptVal,
drawOnCurrentRender,
hadDrawSomething,
drawLinesOnMask,
]
)
useEffect(() => {
emitter.on(EVENT_PROMPT, () => {
if (hadDrawSomething()) {
runInpainting(promptVal)
} else {
setToastState({
open: true,
desc: 'Please draw mask on picture',
state: 'error',
duration: 1500,
})
}
})
return () => {
emitter.off(EVENT_PROMPT)
}
}, [hadDrawSomething, runInpainting, prompt])
const hadRunInpainting = () => {
return renders.length !== 0
}
const handleMultiStrokeKeyDown = () => {
if (isInpaintingLoading) {
if (isInpainting) {
return
}
setIsMultiStrokeKeyPressed(true)
@ -230,13 +289,13 @@ export default function Editor(props: EditorProps) {
if (!isMultiStrokeKeyPressed) {
return
}
if (isInpaintingLoading) {
if (isInpainting) {
return
}
setIsMultiStrokeKeyPressed(false)
if (!settings.runInpaintingManually) {
if (!runMannually) {
runInpainting()
}
}
@ -246,7 +305,7 @@ export default function Editor(props: EditorProps) {
}
useKey(predicate, handleMultiStrokeKeyup, { event: 'keyup' }, [
isInpaintingLoading,
isInpainting,
isMultiStrokeKeyPressed,
hadDrawSomething,
])
@ -257,7 +316,7 @@ export default function Editor(props: EditorProps) {
{
event: 'keydown',
},
[isInpaintingLoading]
[isInpainting]
)
// Draw once the original image is loaded
@ -341,7 +400,7 @@ export default function Editor(props: EditorProps) {
}, [windowSize, resetZoom])
const handleEscPressed = () => {
if (isInpaintingLoading) {
if (isInpainting) {
return
}
if (isDraging || isMultiStrokeKeyPressed) {
@ -361,7 +420,7 @@ export default function Editor(props: EditorProps) {
},
[
isDraging,
isInpaintingLoading,
isInpainting,
isMultiStrokeKeyPressed,
resetZoom,
drawOnCurrentRender,
@ -404,7 +463,7 @@ export default function Editor(props: EditorProps) {
if (!canvas) {
return
}
if (isInpaintingLoading) {
if (isInpainting) {
return
}
if (!isDraging) {
@ -416,13 +475,29 @@ export default function Editor(props: EditorProps) {
return
}
if (settings.runInpaintingManually) {
if (runMannually) {
setIsDraging(false)
} else {
runInpainting()
}
}
const isOutsideCroper = (clickPnt: { x: number; y: number }) => {
if (clickPnt.x < croperRect.x) {
return true
}
if (clickPnt.y < croperRect.y) {
return true
}
if (clickPnt.x > croperRect.x + croperRect.width) {
return true
}
if (clickPnt.y > croperRect.y + croperRect.height) {
return true
}
return false
}
const onMouseDown = (ev: SyntheticEvent) => {
if (isPanning) {
return
@ -434,7 +509,7 @@ export default function Editor(props: EditorProps) {
if (!canvas) {
return
}
if (isInpaintingLoading) {
if (isInpainting) {
return
}
@ -447,10 +522,14 @@ export default function Editor(props: EditorProps) {
return
}
if (isSD && settings.showCroper && isOutsideCroper(mouseXY(ev))) {
return
}
setIsDraging(true)
let lineGroup: LineGroup = []
if (isMultiStrokeKeyPressed || settings.runInpaintingManually) {
if (isMultiStrokeKeyPressed || runMannually) {
lineGroup = [...curLineGroup]
}
lineGroup.push({ size: brushSize, pts: [mouseXY(ev)] })
@ -501,7 +580,7 @@ export default function Editor(props: EditorProps) {
}, [draw, renders, redoRenders, redoLineGroups, lineGroups, original])
const undo = () => {
if (settings.runInpaintingManually && curLineGroup.length !== 0) {
if (runMannually && curLineGroup.length !== 0) {
undoStroke()
} else {
undoRender()
@ -527,14 +606,14 @@ export default function Editor(props: EditorProps) {
useKey(undoPredicate, undo, undefined, [undoStroke, undoRender])
const disableUndo = () => {
if (isInpaintingLoading) {
if (isInpainting) {
return true
}
if (renders.length > 0) {
return false
}
if (settings.runInpaintingManually) {
if (runMannually) {
if (curLineGroup.length === 0) {
return true
}
@ -575,7 +654,7 @@ export default function Editor(props: EditorProps) {
}, [draw, renders, redoRenders, redoLineGroups, lineGroups, original])
const redo = () => {
if (settings.runInpaintingManually && redoCurLines.length !== 0) {
if (runMannually && redoCurLines.length !== 0) {
redoStroke()
} else {
redoRender()
@ -603,14 +682,14 @@ export default function Editor(props: EditorProps) {
useKey(redoPredicate, redo, undefined, [redoStroke, redoRender])
const disableRedo = () => {
if (isInpaintingLoading) {
if (isInpainting) {
return true
}
if (redoRenders.length > 0) {
return false
}
if (settings.runInpaintingManually) {
if (runMannually) {
if (redoCurLines.length === 0) {
return true
}
@ -688,7 +767,7 @@ export default function Editor(props: EditorProps) {
}, [showBrush, isPanning])
// Standard Hotkeys for Brush Size
useKeyPressEvent('[', () => {
useHotKey('[', () => {
setBrushSize(currentBrushSize => {
if (currentBrushSize > 10) {
return currentBrushSize - 10
@ -700,18 +779,23 @@ export default function Editor(props: EditorProps) {
})
})
useKeyPressEvent(']', () => {
useHotKey(']', () => {
setBrushSize(currentBrushSize => {
return currentBrushSize + 10
})
})
// Manual Inpainting Hotkey
useKeyPressEvent('R', () => {
if (settings.runInpaintingManually && hadDrawSomething()) {
runInpainting()
}
})
useHotKey(
'shift+r',
() => {
if (runMannually && hadDrawSomething()) {
runInpainting()
}
},
{},
[runMannually]
)
// Toggle clean/zoom tool on spacebar.
useKeyPressEvent(
@ -792,7 +876,7 @@ export default function Editor(props: EditorProps) {
}}
>
<TransformComponent
contentClass={isInpaintingLoading ? 'editor-canvas-loading' : ''}
contentClass={isInpainting ? 'editor-canvas-loading' : ''}
contentStyle={{
visibility: initialCentered ? 'visible' : 'hidden',
}}
@ -852,10 +936,22 @@ export default function Editor(props: EditorProps) {
/>
</div>
</div>
{settings.showCroper ? (
<Croper
maxHeight={original.naturalHeight}
maxWidth={original.naturalWidth}
minHeight={Math.min(256, original.naturalHeight)}
minWidth={Math.min(256, original.naturalWidth)}
scale={scale}
/>
) : (
<></>
)}
</TransformComponent>
</TransformWrapper>
{showBrush && !isInpaintingLoading && !isPanning && (
{showBrush && !isInpainting && !isPanning && (
<div className="brush-shape" style={getBrushStyle(x, y)} />
)}
@ -867,11 +963,15 @@ export default function Editor(props: EditorProps) {
)}
<div className="editor-toolkit-panel">
<SizeSelector
onChange={onSizeLimitChange}
originalWidth={original.naturalWidth}
originalHeight={original.naturalHeight}
/>
{isSD ? (
<></>
) : (
<SizeSelector
onChange={onSizeLimitChange}
originalWidth={original.naturalWidth}
originalHeight={original.naturalHeight}
/>
)}
<Slider
label="Brush"
min={10}
@ -977,9 +1077,9 @@ export default function Editor(props: EditorProps) {
/>
</svg>
}
disabled={!hadDrawSomething() || isInpaintingLoading}
disabled={!hadDrawSomething() || isInpainting}
onClick={() => {
if (!isInpaintingLoading && hadDrawSomething()) {
if (!isInpainting && hadDrawSomething()) {
runInpainting()
}
}}

View File

@ -1,6 +1,6 @@
header {
height: 60px;
padding: 1rem 2rem;
padding: 1rem 1.5rem;
position: absolute;
top: 0;
display: flex;
@ -31,4 +31,4 @@ header {
align-items: center;
gap: 6px;
justify-self: end;
}
}

View File

@ -1,17 +1,19 @@
import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline'
import React, { useState } from 'react'
import { useRecoilState } from 'recoil'
import { fileState } from '../../store/Atoms'
import { useRecoilState, useRecoilValue } from 'recoil'
import { fileState, isSDState } from '../../store/Atoms'
import Button from '../shared/Button'
import Shortcuts from '../Shortcuts/Shortcuts'
import useResolution from '../../hooks/useResolution'
import { ThemeChanger } from './ThemeChanger'
import SettingIcon from '../Settings/SettingIcon'
import PromptInput from './PromptInput'
const Header = () => {
const [file, setFile] = useRecoilState(fileState)
const resolution = useResolution()
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
const isSD = useRecoilValue(isSDState)
const renderHeader = () => {
return (
@ -37,6 +39,8 @@ const Header = () => {
</label>
</div>
{isSD && file ? <PromptInput /> : <></>}
<div className="header-icons-wrapper">
<ThemeChanger />
{file && (

View File

@ -0,0 +1,18 @@
.prompt-wrapper {
display: flex;
gap: 12px;
}
.prompt-wrapper input {
all: unset;
border-width: 0;
border-radius: 0.5rem;
min-width: 600px;
padding: 0 0.8rem;
outline: 1px solid var(--border-color);
&:focus-visible {
border-width: 0;
outline: 1px solid var(--yellow-accent);
}
}

View File

@ -0,0 +1,44 @@
import React, { FormEvent, useState } from 'react'
import { useRecoilState } from 'recoil'
import emitter, { EVENT_PROMPT } from '../../event'
import { appState, propmtState } from '../../store/Atoms'
import Button from '../shared/Button'
import TextInput from '../shared/Input'
// TODO: show progress in input
const PromptInput = () => {
const [app, setAppState] = useRecoilState(appState)
const [prompt, setPrompt] = useRecoilState(propmtState)
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
evt.preventDefault()
evt.stopPropagation()
const target = evt.target as HTMLInputElement
setPrompt(target.value)
}
const handleRepaintClick = () => {
if (prompt.length !== 0 && !app.isInpainting) {
emitter.emit(EVENT_PROMPT)
}
}
return (
<div className="prompt-wrapper">
<TextInput
value={prompt}
onInput={handleOnInput}
placeholder="I want to repaint of..."
/>
<Button
border
onClick={handleRepaintClick}
disabled={prompt.length === 0 || app.isInpainting}
>
RePaint
</Button>
</div>
)
}
export default PromptInput

View File

@ -1,6 +1,6 @@
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import { AIModel, settingState } from '../../store/Atoms'
import { AIModel, SDSampler, settingState } from '../../store/Atoms'
import Selector from '../shared/Selector'
import { Switch, SwitchThumb } from '../shared/Switch'
import Tooltip from '../shared/Tooltip'
@ -145,6 +145,8 @@ function ModelSettingBlock() {
return undefined
case AIModel.FCF:
return renderFCFModelDesc()
case AIModel.SD14:
return undefined
default:
return <></>
}
@ -182,6 +184,12 @@ function ModelSettingBlock() {
'https://arxiv.org/abs/2208.03382',
'https://github.com/SHI-Labs/FcF-Inpainting'
)
case AIModel.SD14:
return renderModelDesc(
'Stable Diffusion',
'https://ommer-lab.com/research/latent-diffusion-models/',
'https://github.com/CompVis/stable-diffusion'
)
default:
return <></>
}

View File

@ -4,14 +4,28 @@ import SettingBlock from './SettingBlock'
interface NumberInputSettingProps {
title: string
allowFloat?: boolean
desc?: string
value: string
suffix?: string
width?: number
widthUnit?: string
disable?: boolean
onValue: (val: string) => void
}
function NumberInputSetting(props: NumberInputSettingProps) {
const { title, desc, value, suffix, onValue } = props
const {
title,
allowFloat,
desc,
value,
suffix,
onValue,
width,
widthUnit,
disable,
} = props
return (
<SettingBlock
@ -28,8 +42,10 @@ function NumberInputSetting(props: NumberInputSettingProps) {
}}
>
<NumberInput
style={{ width: '80px' }}
allowFloat={allowFloat}
style={{ width: `${width}${widthUnit}` }}
value={`${value}`}
disabled={disable}
onValue={onValue}
/>
{suffix && <span>{suffix}</span>}
@ -39,4 +55,11 @@ function NumberInputSetting(props: NumberInputSettingProps) {
)
}
NumberInputSetting.defaultProps = {
allowFloat: false,
width: 80,
widthUnit: 'px',
disable: false,
}
export default NumberInputSetting

View File

@ -1,13 +1,14 @@
import React from 'react'
import { useRecoilState } from 'recoil'
import { settingState } from '../../store/Atoms'
import { useRecoilState, useRecoilValue } from 'recoil'
import { isSDState, settingState } from '../../store/Atoms'
import Modal from '../shared/Modal'
import ManualRunInpaintingSettingBlock from './ManualRunInpaintingSettingBlock'
import HDSettingBlock from './HDSettingBlock'
import ModelSettingBlock from './ModelSettingBlock'
import GraduallyInpaintingSettingBlock from './GraduallyInpaintingSettingBlock'
import DownloadMaskSettingBlock from './DownloadMaskSettingBlock'
import useHotKey from '../../hooks/useHotkey'
interface SettingModalProps {
onClose: () => void
@ -15,6 +16,7 @@ interface SettingModalProps {
export default function SettingModal(props: SettingModalProps) {
const { onClose } = props
const [setting, setSettingState] = useRecoilState(settingState)
const isSD = useRecoilValue(isSDState)
const handleOnClose = () => {
setSettingState(old => {
@ -23,6 +25,17 @@ export default function SettingModal(props: SettingModalProps) {
onClose()
}
useHotKey(
's',
() => {
setSettingState(old => {
return { ...old, show: !old.show }
})
},
{},
[]
)
return (
<Modal
onClose={handleOnClose}
@ -30,11 +43,12 @@ export default function SettingModal(props: SettingModalProps) {
className="modal-setting"
show={setting.show}
>
<ManualRunInpaintingSettingBlock />
{isSD ? <></> : <ManualRunInpaintingSettingBlock />}
<GraduallyInpaintingSettingBlock />
<DownloadMaskSettingBlock />
<ModelSettingBlock />
<HDSettingBlock />
{isSD ? <></> : <HDSettingBlock />}
</Modal>
)
}

View File

@ -1,6 +1,6 @@
import React from 'react'
import { useKeyPressEvent } from 'react-use'
import { useRecoilState } from 'recoil'
import useHotKey from '../../hooks/useHotkey'
import { shortcutsState } from '../../store/Atoms'
import Button from '../shared/Button'
@ -13,8 +13,7 @@ const Shortcuts = () => {
})
}
useKeyPressEvent('h', ev => {
ev?.preventDefault()
useHotKey('h', () => {
shortcutStateHandler()
})

View File

@ -64,7 +64,8 @@ export default function ShortcutsModal() {
<ShortCut content="Decrease Brush Size" keys={['[']} />
<ShortCut content="Increase Brush Size" keys={[']']} />
<ShortCut content="Toggle Dark Mode" keys={['Shift', 'D']} />
<ShortCut content="Toggle Hotkeys Panel" keys={['H']} />
<ShortCut content="Toggle Hotkeys Dialog" keys={['H']} />
<ShortCut content="Toggle Settings Dialog" keys={['S']} />
</div>
</Modal>
)

View File

@ -0,0 +1,57 @@
@use '../../styles/Mixins/' as *;
.side-panel {
position: absolute;
top: 68px;
right: 1.5rem;
padding: 0.3rem 0.3rem;
z-index: 4;
border-radius: 0.8rem;
border-style: solid;
border-color: var(--border-color);
border-width: 1px;
}
.side-panel-trigger {
font-family: 'WorkSans', sans-serif;
font-size: 16px;
border: 0px;
}
.side-panel-content {
position: relative;
font-family: 'WorkSans', sans-serif;
font-size: 14px;
top: 1rem;
right: 1.5rem;
padding: 1rem 1rem;
z-index: 9;
// backdrop-filter: blur(12px);
color: var(--text-color);
background-color: var(--page-bg);
border-radius: 0.8rem;
border-style: solid;
border-color: var(--border-color);
border-width: 1px;
display: flex;
flex-direction: column;
gap: 12px;
.setting-block-content {
gap: 1rem;
}
// input {
// height: 24px;
// // border-radius: 4px;
// }
// button {
// height: 28px;
// // border-radius: 4px;
// }
}

View File

@ -0,0 +1,188 @@
import React, { useState } from 'react'
import { useRecoilState } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use'
import { SDSampler, settingState } from '../../store/Atoms'
import NumberInputSetting from '../Settings/NumberInputSetting'
import SettingBlock from '../Settings/SettingBlock'
import Selector from '../shared/Selector'
import { Switch, SwitchThumb } from '../shared/Switch'
import Button from '../shared/Button'
import emitter, { EVENT_PROMPT } from '../../event'
const INPUT_WIDTH = 30
// TODO: 添加收起来的按钮
const SidePanel = () => {
const [open, toggleOpen] = useToggle(false)
const [setting, setSettingState] = useRecoilState(settingState)
const onReRunBtnClick = () => {
emitter.emit(EVENT_PROMPT)
}
return (
<div className="side-panel">
<PopoverPrimitive.Root open={open}>
<PopoverPrimitive.Trigger
className="btn-primary side-panel-trigger"
onClick={() => toggleOpen()}
>
Stable Diffusion
</PopoverPrimitive.Trigger>
<PopoverPrimitive.Portal>
<PopoverPrimitive.Content className="side-panel-content">
<SettingBlock
title="Show Croper"
input={
<Switch
checked={setting.showCroper}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, showCroper: value }
})
}}
>
<SwitchThumb />
</Switch>
}
/>
{/*
<NumberInputSetting
title="Num Samples"
width={INPUT_WIDTH}
value={`${setting.sdNumSamples}`}
desc=""
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, sdNumSamples: val }
})
}}
/> */}
<NumberInputSetting
title="Steps"
width={INPUT_WIDTH}
value={`${setting.sdSteps}`}
desc="Large steps result in better result, but more time-consuming"
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, sdSteps: val }
})
}}
/>
<NumberInputSetting
title="Strength"
width={INPUT_WIDTH}
allowFloat
value={`${setting.sdStrength}`}
desc="TODO"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
console.log(val)
setSettingState(old => {
return { ...old, sdStrength: val }
})
}}
/>
<NumberInputSetting
title="Guidance Scale"
width={INPUT_WIDTH}
allowFloat
value={`${setting.sdGuidanceScale}`}
desc="TODO"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, sdGuidanceScale: val }
})
}}
/>
<SettingBlock
className="sub-setting-block"
title="Sampler"
input={
<Selector
width={80}
value={setting.sdSampler as string}
options={Object.values(SDSampler)}
onChange={val => {
const sampler = val as SDSampler
setSettingState(old => {
return { ...old, sdSampler: sampler }
})
}}
/>
}
/>
<SettingBlock
title="Seed"
input={
<div
style={{
display: 'flex',
gap: 0,
justifyContent: 'center',
alignItems: 'center',
}}
>
<Button
onClick={onReRunBtnClick}
icon={
<svg
version="1.1"
xmlns="http://www.w3.org/2000/svg"
width="20px"
height="20px"
viewBox="0 0 66.459 66.46"
>
<path
d="M65.542,11.777L33.467,0.037c-0.133-0.049-0.283-0.049-0.42,0L0.916,11.748c-0.242,0.088-0.402,0.32-0.402,0.576 l0.09,40.484c0,0.25,0.152,0.475,0.385,0.566l31.047,12.399v0.072c0,0.203,0.102,0.393,0.27,0.508 c0.168,0.111,0.379,0.135,0.57,0.062l0.385-0.154l0.385,0.154c0.072,0.028,0.15,0.045,0.227,0.045c0.121,0,0.24-0.037,0.344-0.105 c0.168-0.115,0.27-0.305,0.27-0.508v-0.072l31.047-12.399c0.232-0.093,0.385-0.316,0.385-0.568l0.027-40.453 C65.943,12.095,65.784,11.867,65.542,11.777z M32.035,63.134L3.052,51.562V15.013l28.982,11.572L32.035,63.134L32.035,63.134z M33.259,24.439L4.783,13.066l28.48-10.498l28.735,10.394L33.259,24.439z M63.465,51.562L34.484,63.134V26.585l28.981-11.572 V51.562z M14.478,38.021c0-1.692,1.35-2.528,3.016-1.867c1.665,0.663,3.016,2.573,3.016,4.269 c-0.001,1.692-1.351,2.529-3.017,1.867C15.827,41.626,14.477,39.714,14.478,38.021z M5.998,25.375c0-1.693,1.351-2.529,3.017-1.866 c1.666,0.662,3.016,2.572,3.016,4.267c0,1.695-1.351,2.529-3.017,1.867C7.347,28.979,5.998,27.069,5.998,25.375z M22.959,32.124 c0-1.694,1.351-2.53,3.017-1.867c1.666,0.663,3.016,2.573,3.016,4.267c0,1.695-1.352,2.53-3.017,1.867 C24.309,35.728,22.959,33.818,22.959,32.124z M5.995,43.103c0.001-1.692,1.351-2.529,3.017-1.867 c1.666,0.664,3.016,2.573,3.016,4.269c0,1.694-1.351,2.53-3.017,1.867C7.344,46.709,5.995,44.797,5.995,43.103z M22.957,49.853 c0.001-1.695,1.351-2.529,3.017-1.867s3.016,2.572,3.016,4.269c0,1.692-1.351,2.528-3.017,1.866 C24.306,53.458,22.957,51.546,22.957,49.853z M27.81,12.711c-0.766,1.228-3.209,2.087-5.462,1.917 c-2.253-0.169-3.46-1.301-2.695-2.528c0.765-1.227,3.207-2.085,5.461-1.916C27.365,10.352,28.573,11.484,27.81,12.711z M43.928,13.921c-0.764,1.229-3.208,2.086-5.46,1.917c-2.255-0.169-3.46-1.302-2.696-2.528c0.764-1.229,3.209-2.086,5.462-1.918 C43.485,11.563,44.693,12.695,43.928,13.921z M47.04,42.328c-1.041-1.278-0.764-3.705,0.619-5.421 c1.381-1.716,3.344-2.069,4.381-0.792c1.041,1.276,0.764,3.704-0.617,5.42S48.079,43.604,47.04,42.328z"
fill="currentColor"
/>
</svg>
}
/>
{/* 每次会从服务器返回更新该值 */}
<NumberInputSetting
title=""
width={80}
value={`${setting.sdSeed}`}
desc=""
disable={!setting.sdSeedFixed}
onValue={value => {
const val = value.length === 0 ? 0 : parseInt(value, 10)
setSettingState(old => {
return { ...old, sdSeed: val }
})
}}
/>
<Switch
checked={setting.sdSeedFixed}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, sdSeedFixed: value }
})
}}
style={{ marginLeft: '8px' }}
>
<SwitchThumb />
</Switch>
</div>
}
/>
</PopoverPrimitive.Content>
</PopoverPrimitive.Portal>
</PopoverPrimitive.Root>
</div>
)
}
export default SidePanel

View File

@ -10,6 +10,7 @@ import {
modelDownloaded,
switchModel,
} from '../adapters/inpainting'
import SidePanel from './SidePanel/SidePanel'
interface WorkspaceProps {
file: File
@ -82,6 +83,7 @@ const Workspace = ({ file }: WorkspaceProps) => {
return (
<>
<SidePanel />
<Editor file={file} />
<SettingModal onClose={onSettingClose} />
<ShortcutsModal />

View File

@ -4,6 +4,7 @@
display: grid;
grid-auto-flow: column;
column-gap: 1rem;
background-color: var(--page-bg);
color: var(--btn-text-color);
font-family: 'WorkSans', sans-serif;
width: max-content;
@ -25,6 +26,13 @@
}
.btn-primary-disabled {
background-color: var(--page-bg);
pointer-events: none;
opacity: 0.5;
}
.btn-border {
border-color: var(--btn-border-color);
border-width: 1px;
border-style: solid;
}

View File

@ -1,6 +1,7 @@
import React, { ReactNode } from 'react'
interface ButtonProps {
border?: boolean
disabled?: boolean
children?: ReactNode
className?: string
@ -17,6 +18,7 @@ interface ButtonProps {
const Button: React.FC<ButtonProps> = props => {
const {
children,
border,
className,
disabled,
icon,
@ -55,6 +57,7 @@ const Button: React.FC<ButtonProps> = props => {
toolTip ? 'info-tooltip' : '',
tooltipPosition ? `info-tooltip-${tooltipPosition}` : '',
className,
border ? `btn-border` : '',
].join(' ')}
>
{icon}
@ -65,6 +68,7 @@ const Button: React.FC<ButtonProps> = props => {
Button.defaultProps = {
disabled: false,
border: false,
}
export default Button

View File

@ -0,0 +1,42 @@
import React, { FocusEvent, InputHTMLAttributes } from 'react'
import { useRecoilState } from 'recoil'
import { appState } from '../../store/Atoms'
const TextInput = React.forwardRef<
HTMLInputElement,
InputHTMLAttributes<HTMLInputElement>
>((props: InputHTMLAttributes<HTMLInputElement>, forwardedRef) => {
const { onFocus, onBlur, ...itemProps } = props
const [_, setAppState] = useRecoilState(appState)
const handleOnFocus = (evt: FocusEvent<any>) => {
setAppState(old => {
return { ...old, disableShortCuts: true }
})
onFocus?.(evt)
}
const handleOnBlur = (evt: FocusEvent<any>) => {
setAppState(old => {
return { ...old, disableShortCuts: false }
})
onBlur?.(evt)
}
return (
<input
{...itemProps}
ref={forwardedRef}
type="text"
onFocus={handleOnFocus}
onBlur={handleOnBlur}
onKeyDown={e => {
if (e.key === 'Escape') {
e.currentTarget.blur()
}
}}
/>
)
})
export default TextInput

View File

@ -1,7 +1,9 @@
import { XIcon } from '@heroicons/react/outline'
import React, { ReactNode } from 'react'
import { useRecoilState } from 'recoil'
import * as DialogPrimitive from '@radix-ui/react-dialog'
import Button from './Button'
import { appState } from '../../store/Atoms'
export interface ModalProps {
show: boolean
@ -16,10 +18,14 @@ const Modal = React.forwardRef<
ModalProps
>((props, forwardedRef) => {
const { show, children, onClose, className, title } = props
const [_, setAppState] = useRecoilState(appState)
const onOpenChange = (open: boolean) => {
if (!open) {
onClose?.()
setAppState(old => {
return { ...old, disableShortCuts: false }
})
}
}

View File

@ -5,8 +5,13 @@
padding: 0 0.8rem;
outline: 1px solid var(--border-color);
height: 32px;
text-align: right;
&:focus-visible {
outline: 1px solid var(--yellow-accent);
}
&:disabled {
color: var(--border-color);
}
}

View File

@ -1,31 +1,44 @@
import React, { FormEvent, InputHTMLAttributes } from 'react'
import React, { FormEvent, InputHTMLAttributes, useState } from 'react'
import TextInput from './Input'
interface NumberInputProps extends InputHTMLAttributes<HTMLInputElement> {
value: string
allowFloat?: boolean
onValue?: (val: string) => void
}
const NumberInput = React.forwardRef<HTMLInputElement, NumberInputProps>(
(props: NumberInputProps, forwardedRef) => {
const { value, onValue, ...itemProps } = props
const { value, allowFloat, onValue, ...itemProps } = props
const [innerValue, setInnerValue] = useState(value)
const handleOnInput = (evt: FormEvent<HTMLInputElement>) => {
const target = evt.target as HTMLInputElement
const val = target.value.replace(/\D/g, '')
onValue?.(val)
let val = target.value
if (allowFloat) {
val = val.replace(/[^0-9.]/g, '').replace(/(\..*?)\..*/g, '$1')
onValue?.(val)
} else {
val = val.replace(/\D/g, '')
onValue?.(val)
}
setInnerValue(val)
}
return (
<input
value={value}
<TextInput
value={innerValue}
onInput={handleOnInput}
className="number-input"
{...itemProps}
ref={forwardedRef}
type="text"
/>
)
}
)
NumberInput.defaultProps = {
allowFloat: false,
}
export default NumberInput

View File

@ -51,6 +51,7 @@ const Selector = (props: Props) => {
className="select-trigger"
style={{ width }}
ref={contentRef}
onKeyDown={e => e.preventDefault()}
>
<Select.Value />
<Select.Icon>

View File

@ -12,6 +12,7 @@ const Switch = React.forwardRef<
{...itemProps}
ref={forwardedRef}
className={`switch-root ${className}`}
onKeyDown={e => e.preventDefault()}
/>
)
})

View File

@ -1,7 +1,7 @@
.toast-viewpoint {
position: fixed;
top: 48px;
right: 0;
bottom: 48px;
right: 1.5rem;
display: flex;
flex-direction: row;
padding: 25px;

View File

@ -0,0 +1,7 @@
import mitt from 'mitt'
export const EVENT_PROMPT = 'prompt'
const emitter = mitt()
export default emitter

View File

@ -0,0 +1,22 @@
import { Options, useHotkeys } from 'react-hotkeys-hook'
import { useRecoilValue } from 'recoil'
import { appState } from '../store/Atoms'
const useHotKey = (
keys: string,
callback: any,
options?: Options,
deps?: any[]
) => {
const app = useRecoilValue(appState)
const ref = useHotkeys(
keys,
callback,
{ ...options, enabled: !app.disableShortCuts },
deps
)
return ref
}
export default useHotKey

View File

@ -9,6 +9,7 @@ export enum AIModel {
ZITS = 'zits',
MAT = 'mat',
FCF = 'fcf',
SD14 = 'sd1.4',
}
export const fileState = atom<File | undefined>({
@ -16,6 +17,89 @@ export const fileState = atom<File | undefined>({
default: undefined,
})
export interface Rect {
x: number
y: number
width: number
height: number
}
interface AppState {
disableShortCuts: boolean
isInpainting: boolean
}
export const appState = atom<AppState>({
key: 'appState',
default: {
disableShortCuts: false,
isInpainting: false,
},
})
export const propmtState = atom<string>({
key: 'promptState',
default: '',
})
export const isInpaintingState = selector({
key: 'isInpainting',
get: ({ get }) => {
const app = get(appState)
return app.isInpainting
},
set: ({ get, set }, newValue: any) => {
const app = get(appState)
set(appState, { ...app, isInpainting: newValue })
},
})
export const croperState = atom<Rect>({
key: 'croperState',
default: {
x: 0,
y: 0,
width: 512,
height: 512,
},
})
export const croperX = selector({
key: 'croperX',
get: ({ get }) => get(croperState).x,
set: ({ get, set }, newValue: any) => {
const rect = get(croperState)
set(croperState, { ...rect, x: newValue })
},
})
export const croperY = selector({
key: 'croperY',
get: ({ get }) => get(croperState).y,
set: ({ get, set }, newValue: any) => {
const rect = get(croperState)
set(croperState, { ...rect, y: newValue })
},
})
export const croperHeight = selector({
key: 'croperHeight',
get: ({ get }) => get(croperState).height,
set: ({ get, set }, newValue: any) => {
const rect = get(croperState)
set(croperState, { ...rect, height: newValue })
},
})
export const croperWidth = selector({
key: 'croperWidth',
get: ({ get }) => get(croperState).width,
set: ({ get, set }, newValue: any) => {
const rect = get(croperState)
set(croperState, { ...rect, width: newValue })
},
})
interface ToastAtomState {
open: boolean
desc: string
@ -50,6 +134,7 @@ type ModelsHDSettings = { [key in AIModel]: HDSettings }
export interface Settings {
show: boolean
showCroper: boolean
downloadMask: boolean
graduallyInpainting: boolean
runInpaintingManually: boolean
@ -62,6 +147,16 @@ export interface Settings {
// For ZITS
zitsWireframe: boolean
// For SD
sdMode: SDMode
sdStrength: number
sdSteps: number
sdGuidanceScale: number
sdSampler: SDSampler
sdSeed: number
sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend
sdNumSamples: number
}
const defaultHDSettings: ModelsHDSettings = {
@ -100,10 +195,28 @@ const defaultHDSettings: ModelsHDSettings = {
hdStrategyCropMargin: 128,
enabled: false,
},
[AIModel.SD14]: {
hdStrategy: HDStrategy.ORIGINAL,
hdStrategyResizeLimit: 768,
hdStrategyCropTrigerSize: 512,
hdStrategyCropMargin: 128,
enabled: true,
},
}
export enum SDSampler {
ddim = 'ddim',
}
export enum SDMode {
text2img = 'text2img',
img2img = 'img2img',
inpainting = 'inpainting',
}
export const settingStateDefault: Settings = {
show: false,
showCroper: false,
downloadMask: false,
graduallyInpainting: true,
runInpaintingManually: false,
@ -114,6 +227,16 @@ export const settingStateDefault: Settings = {
ldmSampler: LDMSampler.plms,
zitsWireframe: true,
// SD
sdMode: SDMode.inpainting,
sdStrength: 0.75,
sdSteps: 50,
sdGuidanceScale: 7.5,
sdSampler: SDSampler.ddim,
sdSeed: 42,
sdSeedFixed: true,
sdNumSamples: 1,
}
const localStorageEffect =
@ -164,3 +287,20 @@ export const hdSettingsState = selector({
})
},
})
export const isSDState = selector({
key: 'isSD',
get: ({ get }) => {
const settings = get(settingState)
return settings.model === AIModel.SD14
},
})
export const runManuallyState = selector({
key: 'runManuallyState',
get: ({ get }) => {
const settings = get(settingState)
const isSD = get(isSDState)
return settings.runInpaintingManually || isSD
},
})

View File

@ -6,6 +6,7 @@
--page-bg-light: rgb(255, 255, 255, 0.5);
--page-text-color: #040404;
--yellow-accent: #ffcc00;
--yellow-accent-light: #ffcc0055;
--link-color: rgb(0, 0, 0);
--border-color: rgb(100, 100, 120);
--border-color-light: rgba(100, 100, 120, 0.5);
@ -57,4 +58,6 @@
--box-shadow: inset 0 0.5px rgba(255, 255, 255, 0.1),
inset 0 1px 5px hsl(210 16.7% 97.6%), 0px 0px 0px 0.5px hsl(205 10.7% 78%),
0px 2px 1px -1px hsl(205 10.7% 78%), 0 1px hsl(205 10.7% 78%);
--croper-bg: rgba(0, 0, 0, 0.5);
}

View File

@ -6,6 +6,7 @@
--page-bg-light: #04040488;
--page-text-color: #f9f9f9;
--yellow-accent: #ffcc00;
--yellow-accent-light: #ffcc0055;
--link-color: var(--yellow-accent);
--border-color: rgb(100, 100, 120);
--border-color-light: rgba(102, 102, 102);
@ -55,4 +56,6 @@
--box-shadow: inset 0 0.5px rgba(255, 255, 255, 0.1),
inset 0 1px 5px hsl(195 7.1% 11%), 0px 0px 0px 0.5px hsl(207 5.6% 31.6%),
0px 2px 1px -1px hsl(207 5.6% 31.6%), 0 1px hsl(207 5.6% 31.6%);
--croper-bg: rgba(0, 0, 0, 0.5);
}

View File

@ -9,9 +9,12 @@
@use '../components/Editor/Editor';
@use '../components/LandingPage/LandingPage';
@use '../components/Header/Header';
@use '../components/Header/PromptInput';
@use '../components/Header/ThemeChanger';
@use '../components/Shortcuts/Shortcuts';
@use '../components/Settings/Settings.scss';
@use '../components/SidePanel/SidePanel.scss';
@use '../components/Croper/Croper.scss';
// Shared
@use '../components/FileSelect/FileSelect';

View File

@ -1241,6 +1241,26 @@
minimatch "^3.0.4"
strip-json-comments "^3.1.1"
"@floating-ui/core@^0.7.3":
version "0.7.3"
resolved "https://registry.npmmirror.com/@floating-ui/core/-/core-0.7.3.tgz#d274116678ffae87f6b60e90f88cc4083eefab86"
integrity sha512-buc8BXHmG9l82+OQXOFU3Kr2XQx9ys01U/Q9HMIrZ300iLc8HLMgh7dcCqgYzAzf4BkoQvDcXf5Y+CuEZ5JBYg==
"@floating-ui/dom@^0.5.3":
version "0.5.4"
resolved "https://registry.npmmirror.com/@floating-ui/dom/-/dom-0.5.4.tgz#4eae73f78bcd4bd553ae2ade30e6f1f9c73fe3f1"
integrity sha512-419BMceRLq0RrmTSDxn8hf9R3VCJv2K9PUfugh5JyEFmdjzDo+e8U5EdR8nzKq8Yj1htzLm3b6eQEEam3/rrtg==
dependencies:
"@floating-ui/core" "^0.7.3"
"@floating-ui/react-dom@0.7.2":
version "0.7.2"
resolved "https://registry.npmmirror.com/@floating-ui/react-dom/-/react-dom-0.7.2.tgz#0bf4ceccb777a140fc535c87eb5d6241c8e89864"
integrity sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg==
dependencies:
"@floating-ui/dom" "^0.5.3"
use-isomorphic-layout-effect "^1.1.1"
"@gar/promisify@^1.0.1":
version "1.1.2"
resolved "https://registry.npmjs.org/@gar/promisify/-/promisify-1.1.2.tgz"
@ -1566,6 +1586,13 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/primitive/-/primitive-1.0.0.tgz#e1d8ef30b10ea10e69c76e896f608d9276352253"
integrity sha512-3e7rn8FDMin4CgeL7Z/49smCA3rFYY3Ha2rUQ7HRWFadS5iCRw08ZgVT1LaNTCNqgvrUiyczLflrVrF0SRQtNA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-arrow@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-arrow/-/react-arrow-0.1.4.tgz#a871448a418cd3507d83840fdd47558cb961672b"
@ -1574,6 +1601,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "0.1.4"
"@radix-ui/react-arrow@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-arrow/-/react-arrow-1.0.0.tgz#c461f4c2cab3317e3d42a1ae62910a4cbb0192a1"
integrity sha512-1MUuv24HCdepi41+qfv125EwMuxgQ+U+h0A9K3BjCO/J8nVRREKHHpkD9clwfnjEDk9hgGzCnff4aUKCPiRepw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "1.0.0"
"@radix-ui/react-collection@0.1.5-rc.18":
version "0.1.5-rc.18"
resolved "https://registry.npmmirror.com/@radix-ui/react-collection/-/react-collection-0.1.5-rc.18.tgz#4dc03a8f464643748c0dad781b472f149d671d5c"
@ -1599,6 +1634,13 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-compose-refs/-/react-compose-refs-1.0.0.tgz#37595b1f16ec7f228d698590e78eeed18ff218ae"
integrity sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-context@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-context/-/react-context-0.1.1.tgz#06996829ea124d9a1bc1dbe3e51f33588fab0875"
@ -1613,6 +1655,13 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-context@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-context/-/react-context-1.0.0.tgz#f38e30c5859a9fb5e9aa9a9da452ee3ed9e0aee0"
integrity sha512-1pVM9RfOQ+n/N5PJK33kRSKsr1glNxomxONs5c49MliinBY6Yw2Q995qfBUUo0/Mbg05B/sGA0gkgPI7kmSHBg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-dialog@0.1.8-rc.25":
version "0.1.8-rc.25"
resolved "https://registry.npmmirror.com/@radix-ui/react-dialog/-/react-dialog-0.1.8-rc.25.tgz#dea6af32268b34070346ed5d6d609ff699a1de43"
@ -1667,6 +1716,18 @@
"@radix-ui/react-use-callback-ref" "0.1.1-rc.18"
"@radix-ui/react-use-escape-keydown" "0.1.1-rc.18"
"@radix-ui/react-dismissable-layer@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.0.tgz#35b7826fa262fd84370faef310e627161dffa76b"
integrity sha512-n7kDRfx+LB1zLueRDvZ1Pd0bxdJWDUZNQ/GWoxDn2prnuJKRdxsjulejX/ePkOsLi2tTm6P24mDqlMSgQpsT6g==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "1.0.0"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-primitive" "1.0.0"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-escape-keydown" "1.0.0"
"@radix-ui/react-focus-guards@0.1.1-rc.18":
version "0.1.1-rc.18"
resolved "https://registry.npmmirror.com/@radix-ui/react-focus-guards/-/react-focus-guards-0.1.1-rc.18.tgz#f0e2ebd3cbfd363a71682e3234b274ab7d7df4ce"
@ -1674,6 +1735,13 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-focus-guards@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-focus-guards/-/react-focus-guards-1.0.0.tgz#339c1c69c41628c1a5e655f15f7020bf11aa01fa"
integrity sha512-UagjDk4ijOAnGu4WMUPj9ahi7/zJJqNZ9ZAiGPp7waUWJO0O1aWXi/udPphI0IUjvrhBsZJGSN66dR2dsueLWQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-focus-scope@0.1.5-rc.18":
version "0.1.5-rc.18"
resolved "https://registry.npmmirror.com/@radix-ui/react-focus-scope/-/react-focus-scope-0.1.5-rc.18.tgz#e26a0317130687fd3668af8ec68e19e04dc7668f"
@ -1684,6 +1752,16 @@
"@radix-ui/react-primitive" "0.1.5-rc.18"
"@radix-ui/react-use-callback-ref" "0.1.1-rc.18"
"@radix-ui/react-focus-scope@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-focus-scope/-/react-focus-scope-1.0.0.tgz#95a0c1188276dc8933b1eac5f1cdb6471e01ade5"
integrity sha512-C4SWtsULLGf/2L4oGeIHlvWQx7Rf+7cX/vKOAD2dXW0A1b5QXwi3wWeaEgW+wn+SEVrraMUk05vLU9fZZz5HbQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-primitive" "1.0.0"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-id@0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-0.1.5.tgz#010d311bedd5a2884c1e9bb6aaaa4e6cc1d1d3b8"
@ -1700,6 +1778,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "0.1.1-rc.18"
"@radix-ui/react-id@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-id/-/react-id-1.0.0.tgz#8d43224910741870a45a8c9d092f25887bb6d11e"
integrity sha512-Q6iAB/U7Tq3NTolBBQbHTgclPmGWE3OlktGGqrClPozSw4vkQ1DfQAOtzgRPecKsMdJINE05iaoDUG8tRzCBjw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/react-label@0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-label/-/react-label-0.1.5.tgz#12cd965bfc983e0148121d4c99fb8e27a917c45c"
@ -1722,6 +1808,28 @@
"@radix-ui/react-id" "0.1.6-rc.18"
"@radix-ui/react-primitive" "0.1.5-rc.18"
"@radix-ui/react-popover@^1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-popover/-/react-popover-1.0.0.tgz#5ee72013089fdf9038417fc1eb98a749c17457fd"
integrity sha512-osxFFO0TiZ9ABpEOitZu0R1Fdd+tSpJgAqLZxRLLdZQ7ya0onSODcITp5hXDVuYQeVXH6pKEBGwXN6ZGjZ0a5g==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "1.0.0"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-context" "1.0.0"
"@radix-ui/react-dismissable-layer" "1.0.0"
"@radix-ui/react-focus-guards" "1.0.0"
"@radix-ui/react-focus-scope" "1.0.0"
"@radix-ui/react-id" "1.0.0"
"@radix-ui/react-popper" "1.0.0"
"@radix-ui/react-portal" "1.0.0"
"@radix-ui/react-presence" "1.0.0"
"@radix-ui/react-primitive" "1.0.0"
"@radix-ui/react-slot" "1.0.0"
"@radix-ui/react-use-controllable-state" "1.0.0"
aria-hidden "^1.1.1"
react-remove-scroll "2.5.4"
"@radix-ui/react-popper@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-popper/-/react-popper-0.1.4.tgz#dfc055dcd7dfae6a2eff7a70d333141d15a5d029"
@ -1737,6 +1845,22 @@
"@radix-ui/react-use-size" "0.1.1"
"@radix-ui/rect" "0.1.1"
"@radix-ui/react-popper@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-popper/-/react-popper-1.0.0.tgz#fb4f937864bf39c48f27f55beee61fa9f2bef93c"
integrity sha512-k2dDd+1Wl0XWAMs9ZvAxxYsB9sOsEhrFQV4CINd7IUZf0wfdye4OHen9siwxvZImbzhgVeKTJi68OQmPRvVdMg==
dependencies:
"@babel/runtime" "^7.13.10"
"@floating-ui/react-dom" "0.7.2"
"@radix-ui/react-arrow" "1.0.0"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-context" "1.0.0"
"@radix-ui/react-primitive" "1.0.0"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/react-use-rect" "1.0.0"
"@radix-ui/react-use-size" "1.0.0"
"@radix-ui/rect" "1.0.0"
"@radix-ui/react-portal@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-0.1.4.tgz#17bdce3d7f1a9a0b35cb5e935ab8bc562441a7d2"
@ -1755,6 +1879,14 @@
"@radix-ui/react-primitive" "0.1.5-rc.18"
"@radix-ui/react-use-layout-effect" "0.1.1-rc.18"
"@radix-ui/react-portal@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-portal/-/react-portal-1.0.0.tgz#7220b66743394fabb50c55cb32381395cc4a276b"
integrity sha512-a8qyFO/Xb99d8wQdu4o7qnigNjTPG123uADNecz0eX4usnQEj7o+cG4ZX4zkqq98NYekT7UoEQIjxBNWIFuqTA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "1.0.0"
"@radix-ui/react-presence@0.1.2":
version "0.1.2"
resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-0.1.2.tgz#9f11cce3df73cf65bc348e8b76d891f0d54c1fe3"
@ -1773,6 +1905,15 @@
"@radix-ui/react-compose-refs" "0.1.1-rc.18"
"@radix-ui/react-use-layout-effect" "0.1.1-rc.18"
"@radix-ui/react-presence@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-presence/-/react-presence-1.0.0.tgz#814fe46df11f9a468808a6010e3f3ca7e0b2e84a"
integrity sha512-A+6XEvN01NfVWiKu38ybawfHsBjWum42MRPnEuqPsBZ4eV7e/7K321B5VgYMPv3Xx5An6o1/l9ZuDBgmcmWK3w==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/react-primitive@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-0.1.4.tgz#6c233cf08b0cb87fecd107e9efecb3f21861edc1"
@ -1789,6 +1930,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/react-slot" "0.1.3-rc.18"
"@radix-ui/react-primitive@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-primitive/-/react-primitive-1.0.0.tgz#376cd72b0fcd5e0e04d252ed33eb1b1f025af2b0"
integrity sha512-EyXe6mnRlHZ8b6f4ilTDrXmkLShICIuOTTj0GX4w1rp+wSxf3+TD05u1UOITC8VsJ2a9nwHvdXtOXEOl0Cw/zQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-slot" "1.0.0"
"@radix-ui/react-select@0.1.2-rc.27":
version "0.1.2-rc.27"
resolved "https://registry.npmmirror.com/@radix-ui/react-select/-/react-select-0.1.2-rc.27.tgz#91948d482b3db8cf83172838dfae0f4bedec9566"
@ -1831,6 +1980,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "0.1.1-rc.18"
"@radix-ui/react-slot@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-slot/-/react-slot-1.0.0.tgz#7fa805b99891dea1e862d8f8fbe07f4d6d0fd698"
integrity sha512-3mrKauI/tWXo1Ll+gN5dHcxDPdm/Df1ufcDLCecn+pnCIVcdWE7CujXo8QaXOWRJyZyQWWbpB8eFwHzWXlv5mQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-switch@^0.1.5":
version "0.1.5"
resolved "https://registry.npmmirror.com/@radix-ui/react-switch/-/react-switch-0.1.5.tgz#071ffa19a17a47fdc5c5e6f371bd5901c9fef2f4"
@ -1915,6 +2072,13 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.0.tgz#9e7b8b6b4946fe3cbe8f748c82a2cce54e7b6a90"
integrity sha512-GZtyzoHz95Rhs6S63D2t/eqvdFCm7I+yHMLVQheKM7nBD8mbZIt+ct1jz4536MDnaOGKIxynJ8eHTkVGVVkoTg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-controllable-state@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-0.1.0.tgz#4fced164acfc69a4e34fb9d193afdab973a55de1"
@ -1931,6 +2095,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "0.1.1-rc.18"
"@radix-ui/react-use-controllable-state@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.0.0.tgz#a64deaafbbc52d5d407afaa22d493d687c538b7f"
integrity sha512-FohDoZvk3mEXh9AWAVyRTYR4Sq7/gavuofglmiXB2g1aKyboUD4YtgWxKj8O5n+Uak52gXQ4wKz5IFST4vtJHg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-escape-keydown@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-0.1.0.tgz#dc80cb3753e9d1bd992adbad9a149fb6ea941874"
@ -1947,6 +2119,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "0.1.1-rc.18"
"@radix-ui/react-use-escape-keydown@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.0.0.tgz#aef375db4736b9de38a5a679f6f49b45a060e5d1"
integrity sha512-JwfBCUIfhXRxKExgIqGa4CQsiMemo1Xt0W/B4ei3fpzpvPENKpMKQ8mZSB6Acj3ebrAEgi2xiQvcI1PAAodvyg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-layout-effect@0.1.0":
version "0.1.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-0.1.0.tgz#ebf71bd6d2825de8f1fbb984abf2293823f0f223"
@ -1961,6 +2141,13 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.0.0.tgz#2fc19e97223a81de64cd3ba1dc42ceffd82374dc"
integrity sha512-6Tpkq+R6LOlmQb1R5NNETLG0B4YP0wc+klfXafpUCj6JGyaUc8il7/kUZ7m59rGbXGczE9Bs+iz2qloqsZBduQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-previous@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-previous/-/react-use-previous-0.1.1.tgz#0226017f72267200f6e832a7103760e96a6db5d0"
@ -1983,6 +2170,14 @@
"@babel/runtime" "^7.13.10"
"@radix-ui/rect" "0.1.1"
"@radix-ui/react-use-rect@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-rect/-/react-use-rect-1.0.0.tgz#b040cc88a4906b78696cd3a32b075ed5b1423b3e"
integrity sha512-TB7pID8NRMEHxb/qQJpvSt3hQU4sqNPM1VCTjTRjEOa7cEop/QMuq8S6fb/5Tsz64kqSvB9WnwsDHtjnrM9qew==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/rect" "1.0.0"
"@radix-ui/react-use-size@0.1.1":
version "0.1.1"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-size/-/react-use-size-0.1.1.tgz#f6b75272a5d41c3089ca78c8a2e48e5f204ef90f"
@ -1990,6 +2185,14 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-size@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/react-use-size/-/react-use-size-1.0.0.tgz#a0b455ac826749419f6354dc733e2ca465054771"
integrity sha512-imZ3aYcoYCKhhgNpkNDh/aTiU05qw9hX+HHI1QDBTyIlcFjgeFlKKySNGMwTp7nYFLQg/j0VA2FmCY4WPDDHMg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/react-visually-hidden@0.1.4":
version "0.1.4"
resolved "https://registry.npmmirror.com/@radix-ui/react-visually-hidden/-/react-visually-hidden-0.1.4.tgz#6c75eae34fb5d084b503506fbfc05587ced05f03"
@ -2013,6 +2216,13 @@
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/rect@1.0.0":
version "1.0.0"
resolved "https://registry.npmmirror.com/@radix-ui/rect/-/rect-1.0.0.tgz#0dc8e6a829ea2828d53cbc94b81793ba6383bf3c"
integrity sha512-d0O68AYy/9oeEy1DdC07bz1/ZXX+DqCskRd3i4JzLSTXwefzaepQrKjXC7aNM8lTHjFLDO0pDgaEiQ7jEk+HVg==
dependencies:
"@babel/runtime" "^7.13.10"
"@rollup/plugin-node-resolve@^7.1.1":
version "7.1.3"
resolved "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-7.1.3.tgz"
@ -2055,6 +2265,11 @@
dependencies:
"@sinonjs/commons" "^1.7.0"
"@socket.io/component-emitter@~3.1.0":
version "3.1.0"
resolved "https://registry.npmmirror.com/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz#96116f2a912e0c02817345b3c10751069920d553"
integrity sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==
"@surma/rollup-plugin-off-main-thread@^1.1.1":
version "1.4.2"
resolved "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-1.4.2.tgz"
@ -4621,6 +4836,13 @@ debug@^3.1.1, debug@^3.2.6, debug@^3.2.7:
dependencies:
ms "^2.1.1"
debug@~4.3.1, debug@~4.3.2:
version "4.3.4"
resolved "https://registry.npmmirror.com/debug/-/debug-4.3.4.tgz#1319f6579357f2338d3337d2cdd4914bb5dcc865"
integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==
dependencies:
ms "2.1.2"
decamelize@^1.2.0:
version "1.2.0"
resolved "https://registry.npmjs.org/decamelize/-/decamelize-1.2.0.tgz"
@ -5004,6 +5226,22 @@ end-of-stream@^1.0.0, end-of-stream@^1.1.0:
dependencies:
once "^1.4.0"
engine.io-client@~6.2.1:
version "6.2.2"
resolved "https://registry.npmmirror.com/engine.io-client/-/engine.io-client-6.2.2.tgz#c6c5243167f5943dcd9c4abee1bfc634aa2cbdd0"
integrity sha512-8ZQmx0LQGRTYkHuogVZuGSpDqYZtCM/nv8zQ68VZ+JkOpazJ7ICdsSpaO6iXwvaU30oFg5QJOJWj8zWqhbKjkQ==
dependencies:
"@socket.io/component-emitter" "~3.1.0"
debug "~4.3.1"
engine.io-parser "~5.0.3"
ws "~8.2.3"
xmlhttprequest-ssl "~2.0.0"
engine.io-parser@~5.0.3:
version "5.0.4"
resolved "https://registry.npmmirror.com/engine.io-parser/-/engine.io-parser-5.0.4.tgz#0b13f704fa9271b3ec4f33112410d8f3f41d0fc0"
integrity sha512-+nVFp+5z1E3HcToEnO7ZIj3g+3k9389DvWtvJZz0T6/eOCPIyyxehFcedoYrZQrp0LgQbD9pPXhpMBKMd5QURg==
enhanced-resolve@^4.3.0:
version "4.5.0"
resolved "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-4.5.0.tgz"
@ -6215,6 +6453,11 @@ hosted-git-info@^2.1.4:
resolved "https://registry.npmjs.org/hosted-git-info/-/hosted-git-info-2.8.9.tgz"
integrity sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==
hotkeys-js@3.9.4:
version "3.9.4"
resolved "https://registry.npmmirror.com/hotkeys-js/-/hotkeys-js-3.9.4.tgz#ce1aa4c3a132b6a63a9dd5644fc92b8a9b9cbfb9"
integrity sha512-2zuLt85Ta+gIyvs4N88pCYskNrxf1TFv3LR9t5mdAZIX8BcgQQ48F2opUptvHa6m8zsy5v/a0i9mWzTrlNWU0Q==
hpack.js@^2.1.6:
version "2.1.6"
resolved "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz"
@ -7785,16 +8028,11 @@ lodash.uniq@^4.5.0:
resolved "https://registry.npmjs.org/lodash.uniq/-/lodash.uniq-4.5.0.tgz"
integrity sha1-0CJTc662Uq3BvILklFM5qEJ1R3M=
"lodash@>=3.5 <5", lodash@^4.17.11, lodash@^4.17.14, lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.20, lodash@^4.17.5, lodash@^4.7.0:
"lodash@>=3.5 <5", lodash@^4.17.11, lodash@^4.17.14, lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.20, lodash@^4.17.21, lodash@^4.17.5, lodash@^4.7.0:
version "4.17.21"
resolved "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz"
integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==
lodash@^4.17.21:
version "4.17.21"
resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==
loglevel@^1.6.8:
version "1.7.1"
resolved "https://registry.npmjs.org/loglevel/-/loglevel-1.7.1.tgz"
@ -8095,6 +8333,11 @@ mississippi@^3.0.0:
stream-each "^1.1.0"
through2 "^2.0.0"
mitt@^3.0.0:
version "3.0.0"
resolved "https://registry.npmmirror.com/mitt/-/mitt-3.0.0.tgz#69ef9bd5c80ff6f57473e8d89326d01c414be0bd"
integrity sha512-7dX2/10ITVyqh4aOSVI9gdape+t9l2/8QxHrFmUXu4EEUpdlxl6RudZUPZoc+zuY2hk1j7XxVroIVIan/pD/SQ==
mixin-deep@^1.2.0:
version "1.3.2"
resolved "https://registry.npmjs.org/mixin-deep/-/mixin-deep-1.3.2.tgz"
@ -9924,6 +10167,13 @@ react-error-overlay@^6.0.9:
resolved "https://registry.npmjs.org/react-error-overlay/-/react-error-overlay-6.0.9.tgz"
integrity sha512-nQTTcUu+ATDbrSD1BZHr5kgSD4oF8OFjxun8uAaL8RwPBacGBNPf/yAuVVdx17N8XNzRDMrZ9XcKZHCjPW+9ew==
react-hotkeys-hook@^3.4.7:
version "3.4.7"
resolved "https://registry.npmmirror.com/react-hotkeys-hook/-/react-hotkeys-hook-3.4.7.tgz#e16a0a85f59feed9f48d12cfaf166d7df4c96b7a"
integrity sha512-+bbPmhPAl6ns9VkXkNNyxlmCAIyDAcWbB76O4I0ntr3uWCRuIQf/aRLartUahe9chVMPj+OEzzfk3CQSjclUEQ==
dependencies:
hotkeys-js "3.9.4"
react-is@^16.8.1:
version "16.13.1"
resolved "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz"
@ -9947,6 +10197,25 @@ react-remove-scroll-bar@^2.3.0:
react-style-singleton "^2.2.0"
tslib "^2.0.0"
react-remove-scroll-bar@^2.3.3:
version "2.3.3"
resolved "https://registry.npmmirror.com/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.3.tgz#e291f71b1bb30f5f67f023765b7435f4b2b2cd94"
integrity sha512-i9GMNWwpz8XpUpQ6QlevUtFjHGqnPG4Hxs+wlIJntu/xcsZVEpJcIV71K3ZkqNy2q3GfgvkD7y6t/Sv8ofYSbw==
dependencies:
react-style-singleton "^2.2.1"
tslib "^2.0.0"
react-remove-scroll@2.5.4:
version "2.5.4"
resolved "https://registry.npmmirror.com/react-remove-scroll/-/react-remove-scroll-2.5.4.tgz#afe6491acabde26f628f844b67647645488d2ea0"
integrity sha512-xGVKJJr0SJGQVirVFAUZ2k1QLyO6m+2fy0l8Qawbp5Jgrv3DeLalrfMNBFSlmz5kriGGzsVBtGVnf4pTKIhhWA==
dependencies:
react-remove-scroll-bar "^2.3.3"
react-style-singleton "^2.2.1"
tslib "^2.1.0"
use-callback-ref "^1.3.0"
use-sidecar "^1.1.2"
react-remove-scroll@^2.4.0:
version "2.5.1"
resolved "https://registry.npmmirror.com/react-remove-scroll/-/react-remove-scroll-2.5.1.tgz#28c318c2e076040e5d6172bf28aab2916ad89b46"
@ -10033,6 +10302,15 @@ react-style-singleton@^2.2.0:
invariant "^2.2.4"
tslib "^2.0.0"
react-style-singleton@^2.2.1:
version "2.2.1"
resolved "https://registry.npmmirror.com/react-style-singleton/-/react-style-singleton-2.2.1.tgz#f99e420492b2d8f34d38308ff660b60d0b1205b4"
integrity sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==
dependencies:
get-nonce "^1.0.0"
invariant "^2.2.4"
tslib "^2.0.0"
react-universal-interface@^0.6.2:
version "0.6.2"
resolved "https://registry.npmjs.org/react-universal-interface/-/react-universal-interface-0.6.2.tgz"
@ -10845,6 +11123,24 @@ snapdragon@^0.8.1:
source-map-resolve "^0.5.0"
use "^3.1.0"
socket.io-client@^4.5.2:
version "4.5.2"
resolved "https://registry.npmmirror.com/socket.io-client/-/socket.io-client-4.5.2.tgz#9481518c560388c980c88b01e3cf62f367f04c96"
integrity sha512-naqYfFu7CLDiQ1B7AlLhRXKX3gdeaIMfgigwavDzgJoIUYulc1qHH5+2XflTsXTPY7BlPH5rppJyUjhjrKQKLg==
dependencies:
"@socket.io/component-emitter" "~3.1.0"
debug "~4.3.2"
engine.io-client "~6.2.1"
socket.io-parser "~4.2.0"
socket.io-parser@~4.2.0:
version "4.2.1"
resolved "https://registry.npmmirror.com/socket.io-parser/-/socket.io-parser-4.2.1.tgz#01c96efa11ded938dcb21cbe590c26af5eff65e5"
integrity sha512-V4GrkLy+HeF1F/en3SpUaM+7XxYXpuMUWLGde1kSSh5nQMN4hLrbPIkD+otwh6q9R6NOQBN4AMaOZ2zVjui82g==
dependencies:
"@socket.io/component-emitter" "~3.1.0"
debug "~4.3.1"
sockjs-client@^1.5.0:
version "1.5.2"
resolved "https://registry.npmjs.org/sockjs-client/-/sockjs-client-1.5.2.tgz"
@ -11855,6 +12151,11 @@ use-callback-ref@^1.3.0:
dependencies:
tslib "^2.0.0"
use-isomorphic-layout-effect@^1.1.1:
version "1.1.2"
resolved "https://registry.npmmirror.com/use-isomorphic-layout-effect/-/use-isomorphic-layout-effect-1.1.2.tgz#497cefb13d863d687b08477d9e5a164ad8c1a6fb"
integrity sha512-49L8yCO3iGT/ZF9QttjwLF/ZD9Iwto5LnH5LmEdk/6cFmXddqi2ulF0edxTwjj+7mqvpVVGQWvbXZdn32wRSHA==
use-sidecar@^1.1.2:
version "1.1.2"
resolved "https://registry.npmmirror.com/use-sidecar/-/use-sidecar-1.1.2.tgz#2f43126ba2d7d7e117aa5855e5d8f0276dfe73c2"
@ -12410,6 +12711,11 @@ ws@^7.4.6:
resolved "https://registry.npmjs.org/ws/-/ws-7.5.5.tgz"
integrity sha512-BAkMFcAzl8as1G/hArkxOxq3G7pjUqQ3gzYbLL0/5zNkph70e+lCoxBGnm6AW1+/aiNeV4fnKqZ8m4GZewmH2w==
ws@~8.2.3:
version "8.2.3"
resolved "https://registry.npmmirror.com/ws/-/ws-8.2.3.tgz#63a56456db1b04367d0b721a0b80cae6d8becbba"
integrity sha512-wBuoj1BDpC6ZQ1B7DWQBYVLphPWkm8i9Y0/3YdHjHKHiohOJ1ws+3OccDWtH+PoC9DZD5WOTrJvNbWvjS6JWaA==
xml-name-validator@^3.0.0:
version "3.0.0"
resolved "https://registry.npmjs.org/xml-name-validator/-/xml-name-validator-3.0.0.tgz"
@ -12420,6 +12726,11 @@ xmlchars@^2.2.0:
resolved "https://registry.npmjs.org/xmlchars/-/xmlchars-2.2.0.tgz"
integrity sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==
xmlhttprequest-ssl@~2.0.0:
version "2.0.0"
resolved "https://registry.npmmirror.com/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz#91360c86b914e67f44dce769180027c0da618c67"
integrity sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A==
xtend@^4.0.0, xtend@~4.0.1:
version "4.0.2"
resolved "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz"

View File

@ -14,17 +14,17 @@ class InpaintModel:
pad_mod = 8
pad_to_square = False
def __init__(self, device):
def __init__(self, device, **kwargs):
"""
Args:
device:
"""
self.device = device
self.init_model(device)
self.init_model(device, **kwargs)
@abc.abstractmethod
def init_model(self, device):
def init_model(self, device, **kwargs):
...
@staticmethod
@ -36,15 +36,19 @@ class InpaintModel:
def forward(self, image, mask, config: Config):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W] 255 masks 区域
masks: [H, W, 1] 255 masks 区域
return: BGR IMAGE
"""
...
def _pad_forward(self, image, mask, config: Config):
origin_height, origin_width = image.shape[:2]
pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
pad_image = pad_img_to_modulo(
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
)
pad_mask = pad_img_to_modulo(
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
)
logger.info(f"final forward pad size: {pad_image.shape}")
@ -81,18 +85,30 @@ class InpaintModel:
elif config.hd_strategy == HDStrategy.RESIZE:
if max(image.shape) > config.hd_strategy_resize_limit:
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
downsize_image = resize_max_size(
image, size_limit=config.hd_strategy_resize_limit
)
downsize_mask = resize_max_size(
mask, size_limit=config.hd_strategy_resize_limit
)
logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
logger.info(
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
)
inpaint_result = self._pad_forward(
downsize_image, downsize_mask, config
)
# only paste masked area result
inpaint_result = cv2.resize(inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC)
inpaint_result = cv2.resize(
inpaint_result,
(origin_size[1], origin_size[0]),
interpolation=cv2.INTER_CUBIC,
)
original_pixel_indices = mask < 127
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
original_pixel_indices
]
if inpaint_result is None:
inpaint_result = self._pad_forward(image, mask, config)
@ -133,11 +149,11 @@ class InpaintModel:
if _l < 0:
r += abs(_l)
if _r > img_w:
l -= (_r - img_w)
l -= _r - img_w
if _t < 0:
b += abs(_t)
if _b > img_h:
t -= (_b - img_h)
t -= _b - img_h
l = max(l, 0)
r = min(r, img_w)

View File

@ -1135,7 +1135,7 @@ class FcF(InpaintModel):
pad_mod = 512
pad_to_square = True
def init_model(self, device):
def init_model(self, device, **kwargs):
seed = 0
random.seed(seed)
np.random.seed(seed)

View File

@ -18,16 +18,7 @@ LAMA_MODEL_URL = os.environ.get(
class LaMa(InpaintModel):
pad_mod = 8
def __init__(self, device):
"""
Args:
device:
"""
super().__init__(device)
self.device = device
def init_model(self, device):
def init_model(self, device, **kwargs):
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):

View File

@ -227,7 +227,7 @@ class LDM(InpaintModel):
super().__init__(device)
self.device = device
def init_model(self, device):
def init_model(self, device, **kwargs):
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)

View File

@ -1405,7 +1405,7 @@ class MAT(InpaintModel):
pad_mod = 512
pad_to_square = True
def init_model(self, device):
def init_model(self, device, **kwargs):
seed = 240 # pick up a random number
random.seed(seed)
np.random.seed(seed)

152
lama_cleaner/model/sd.py Normal file
View File

@ -0,0 +1,152 @@
import random
import PIL.Image
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.helper import norm_img
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
#
#
# def preprocess_image(image):
# w, h = image.size
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
# image = image.resize((w, h), resample=PIL.Image.LANCZOS)
# image = np.array(image).astype(np.float32) / 255.0
# image = image[None].transpose(0, 3, 1, 2)
# image = torch.from_numpy(image)
# # [-1, 1]
# return 2.0 * image - 1.0
#
#
# def preprocess_mask(mask):
# mask = mask.convert("L")
# w, h = mask.size
# w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
# mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
# mask = np.array(mask).astype(np.float32) / 255.0
# mask = np.tile(mask, (4, 1, 1))
# mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
# mask = 1 - mask # repaint white, keep black
# mask = torch.from_numpy(mask)
# return mask
class SD(InpaintModel):
pad_mod = 32
min_size = 512
def init_model(self, device: torch.device, **kwargs):
# return
from .sd_pipeline import StableDiffusionInpaintPipeline
self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path,
revision="fp16",
torch_dtype=torch.float16,
use_auth_token=kwargs["hf_access_token"],
)
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
self.model = self.model.to(device)
self.callbacks = kwargs.pop("callbacks", None)
@torch.cuda.amp.autocast()
def forward(self, image, mask, config: Config):
# return image
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
# image = norm_img(image) # [0, 1]
# image = image * 2 - 1 # [0, 1] -> [-1, 1]
# resize to latent feature map size
# h, w = mask.shape[:2]
# mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA)
# mask = norm_img(mask)
#
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
# import time
# time.sleep(2)
# return image
seed = config.sd_seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
output = self.model(
prompt=config.prompt,
init_image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
strength=config.sd_strength,
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np.array",
callbacks=self.callbacks,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
img_h, img_w = image.shape[:2]
# boxes = boxes_from_mask(mask)
if config.use_croper:
logger.info("use croper")
l, t, w, h = (
config.croper_x,
config.croper_y,
config.croper_width,
config.croper_height,
)
r = l + w
b = t + h
l = max(l, 0)
r = min(r, img_w)
t = max(t, 0)
b = min(b, img_h)
crop_img = image[t:b, l:r, :]
crop_mask = mask[t:b, l:r]
crop_image = self._pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image
else:
inpaint_result = self._pad_forward(image, mask, config)
return inpaint_result
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings
return True
class SD14(SD):
model_id_or_path = "CompVis/stable-diffusion-v1-4"
class SD15(SD):
model_id_or_path = "CompVis/stable-diffusion-v1-5"

View File

@ -0,0 +1,309 @@
import inspect
from typing import List, Optional, Union, Callable
import numpy as np
import torch
import PIL
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.utils import logging
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__)
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
mask = 1 - mask # repaint white, keep black
mask = torch.from_numpy(mask)
return mask
class StableDiffusionInpaintPipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slice(None)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callbacks: List[Callable[[int], None]] = None
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
converted to a single channel (luminance) before use.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
num_inference_steps (`int`, *optional*, defaults to 50):
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
offset = 0
if accepts_offset:
offset = 1
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
# preprocess image
init_image = preprocess_image(init_image).to(self.device)
# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size)
init_latents_orig = init_latents
# preprocess mask
mask = preprocess_mask(mask_image).to(self.device)
mask = torch.cat([mask] * batch_size)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# get prompt text embeddings
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
max_length = text_input.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
latents = (init_latents_proper * mask) + (latents * (1 - mask))
if callbacks is not None:
for callback in callbacks:
callback(i)
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -216,7 +216,7 @@ class ZITS(InpaintModel):
self.device = device
self.sample_edge_line_iterations = 1
def init_model(self, device):
def init_model(self, device, **kwargs):
self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device)
self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device)
self.structure_upsample = load_jit_model(

View File

@ -2,27 +2,23 @@ from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.mat import MAT
from lama_cleaner.model.sd import SD14
from lama_cleaner.model.zits import ZITS
from lama_cleaner.schema import Config
models = {
'lama': LaMa,
'ldm': LDM,
'zits': ZITS,
'mat': MAT,
'fcf': FcF
}
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.4": SD14}
class ModelManager:
def __init__(self, name: str, device):
def __init__(self, name: str, device, **kwargs):
self.name = name
self.device = device
self.model = self.init_model(name, device)
self.kwargs = kwargs
self.model = self.init_model(name, device, **kwargs)
def init_model(self, name: str, device):
def init_model(self, name: str, device, **kwargs):
if name in models:
model = models[name](device)
model = models[name](device, **kwargs)
else:
raise NotImplementedError(f"Not supported model: {name}")
return model
@ -40,7 +36,7 @@ class ModelManager:
if new_name == self.name:
return
try:
self.model = self.init_model(new_name, self.device)
self.model = self.init_model(new_name, self.device, **self.kwargs)
self.name = new_name
except NotImplementedError as e:
raise e

View File

@ -7,7 +7,16 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm", "zits", "mat", 'fcf'])
parser.add_argument(
"--model",
default="lama",
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.4"],
)
parser.add_argument(
"--hf_access_token",
default="",
help="huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens",
)
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(
@ -29,4 +38,10 @@ def parse_args():
if imghdr.what(args.input) is None:
parser.error(f"invalid --input: {args.input} is not a valid image file")
if args.model.startswith("sd"):
if not args.hf_access_token.startswith("hf_"):
parser.error(
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"
)
return args

View File

@ -22,3 +22,19 @@ class Config(BaseModel):
hd_strategy_crop_margin: int
hd_strategy_crop_trigger_size: int
hd_strategy_resize_limit: int
prompt: str = ''
# 始终是在原图尺度上的值
use_croper: bool = False
croper_x: int = None
croper_y: int = None
croper_height: int = None
croper_width: int = None
# sd
sd_strength: float = 0.75
sd_steps: int = 50
sd_guidance_scale: float = 7.5
sd_sampler: str = 'ddim' # ddim/pndm
# -1 mean random seed
sd_seed: int = 42

View File

@ -4,6 +4,7 @@ import io
import logging
import multiprocessing
import os
import random
import time
import imghdr
from pathlib import Path
@ -41,7 +42,7 @@ from lama_cleaner.helper import (
NUM_THREADS = str(multiprocessing.cpu_count())
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"]="True"
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
@ -64,6 +65,10 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition"])
# MAX_BUFFER_SIZE = 50 * 1000 * 1000 # 50 MB
# async_mode 优先级: eventlet/gevent_uwsgi/gevent/threading
# only threading works on macOS
# socketio = SocketIO(app, max_http_buffer_size=MAX_BUFFER_SIZE, async_mode='threading')
model: ModelManager = None
device = None
@ -77,6 +82,11 @@ def get_image_ext(img_bytes):
return w
def diffuser_callback(step: int):
pass
# socketio.emit('diffusion_step', {'diffusion_step': step})
@app.route("/inpaint", methods=["POST"])
def process():
input = request.files
@ -102,8 +112,24 @@ def process():
hd_strategy_crop_margin=form["hdStrategyCropMargin"],
hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
prompt=form['prompt'],
use_croper=form['useCroper'],
croper_x=form['croperX'],
croper_y=form['croperY'],
croper_height=form['croperHeight'],
croper_width=form['croperWidth'],
sd_strength=form["sdStrength"],
sd_steps=form["sdSteps"],
sd_guidance_scale=form["sdGuidanceScale"],
sd_sampler=form["sdSampler"],
sd_seed=form["sdSeed"],
)
if config.sd_seed == -1:
config.sd_seed = random.randint(1, 9999999)
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
logger.info(f"Resized image shape: {image.shape}")
@ -184,7 +210,8 @@ def main(args):
device = torch.device(args.device)
input_image_path = args.input
model = ModelManager(name=args.model, device=device)
model = ModelManager(name=args.model, device=device, hf_access_token=args.hf_access_token,
callbacks=[diffuser_callback])
if args.gui:
app_width, app_height = args.gui_size
@ -195,4 +222,5 @@ def main(args):
)
ui.run()
else:
# TODO: socketio
app.run(host=args.host, port=args.port, debug=args.debug)

Binary file not shown.

After

Width:  |  Height:  |  Size: 395 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

View File

@ -1,3 +1,4 @@
import os
from pathlib import Path
import cv2
@ -11,10 +12,10 @@ current_dir = Path(__file__).parent.absolute().resolve()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_data(fx=1, fy=1.0):
img = cv2.imread(str(current_dir / "image.png"))
def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
img = cv2.imread(str(img_p))
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
if fx != 1:
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
@ -35,8 +36,8 @@ def get_config(strategy, **kwargs):
return Config(**data)
def assert_equal(model, config, gt_name, fx=1, fy=1):
img, mask = get_data(fx=fx, fy=fy)
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
res = model(img, mask, config)
cv2.imwrite(
str(current_dir / gt_name),
@ -153,3 +154,26 @@ def test_fcf(strategy):
fx=3.8,
fy=2
)
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_sd(strategy, capfd):
def callback(step: int):
print(f"sd_step_{step}")
sd_steps = 2
model = ModelManager(name="sd", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], callbacks=[callback])
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
assert_equal(
model,
cfg,
f"sd_{strategy.capitalize()}_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=0.5,
fy=0.5
)
captured = capfd.readouterr()
for i in range(sd_steps):
assert f'sd_step_{i}' in captured.out

View File

@ -10,3 +10,5 @@ pytest
yacs
markupsafe==2.0.1
scikit-image==0.19.3
diffusers==0.3.0
transformers==4.20.0