add custom mask upload, WIP, need more test
better handle server error
This commit is contained in:
parent
0666a32947
commit
eec41734c3
@ -5,19 +5,23 @@ export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
|||||||
|
|
||||||
export default async function inpaint(
|
export default async function inpaint(
|
||||||
imageFile: File,
|
imageFile: File,
|
||||||
maskBase64: string,
|
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
croperRect: Rect,
|
croperRect: Rect,
|
||||||
prompt?: string,
|
prompt?: string,
|
||||||
negativePrompt?: string,
|
negativePrompt?: string,
|
||||||
sizeLimit?: string,
|
sizeLimit?: string,
|
||||||
seed?: number
|
seed?: number,
|
||||||
|
maskBase64?: string,
|
||||||
|
customMask?: File
|
||||||
) {
|
) {
|
||||||
// 1080, 2000, Original
|
// 1080, 2000, Original
|
||||||
const fd = new FormData()
|
const fd = new FormData()
|
||||||
fd.append('image', imageFile)
|
fd.append('image', imageFile)
|
||||||
const mask = dataURItoBlob(maskBase64)
|
if (maskBase64 !== undefined) {
|
||||||
fd.append('mask', mask)
|
fd.append('mask', dataURItoBlob(maskBase64))
|
||||||
|
} else if (customMask !== undefined) {
|
||||||
|
fd.append('mask', customMask)
|
||||||
|
}
|
||||||
|
|
||||||
const hdSettings = settings.hdSettings[settings.model]
|
const hdSettings = settings.hdSettings[settings.model]
|
||||||
fd.append('ldmSteps', settings.ldmSteps.toString())
|
fd.append('ldmSteps', settings.ldmSteps.toString())
|
||||||
@ -70,8 +74,10 @@ export default async function inpaint(
|
|||||||
const newSeed = res.headers.get('x-seed')
|
const newSeed = res.headers.get('x-seed')
|
||||||
return { blob: URL.createObjectURL(blob), seed: newSeed }
|
return { blob: URL.createObjectURL(blob), seed: newSeed }
|
||||||
}
|
}
|
||||||
} catch {
|
const errMsg = await res.text()
|
||||||
throw new Error('Something went wrong on server side.')
|
throw new Error(errMsg)
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Something went wrong: ${error}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,7 +45,11 @@ import {
|
|||||||
} from '../../store/Atoms'
|
} from '../../store/Atoms'
|
||||||
import useHotKey from '../../hooks/useHotkey'
|
import useHotKey from '../../hooks/useHotkey'
|
||||||
import Croper from '../Croper/Croper'
|
import Croper from '../Croper/Croper'
|
||||||
import emitter, { EVENT_PROMPT } from '../../event'
|
import emitter, {
|
||||||
|
EVENT_PROMPT,
|
||||||
|
EVENT_CUSTOM_MASK,
|
||||||
|
CustomMaskEventData,
|
||||||
|
} from '../../event'
|
||||||
import FileSelect from '../FileSelect/FileSelect'
|
import FileSelect from '../FileSelect/FileSelect'
|
||||||
|
|
||||||
const TOOLBAR_SIZE = 200
|
const TOOLBAR_SIZE = 200
|
||||||
@ -195,22 +199,23 @@ export default function Editor() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
const runInpainting = useCallback(
|
const runInpainting = useCallback(
|
||||||
async (prompt?: string, useLastLineGroup?: boolean) => {
|
async (useLastLineGroup?: boolean, customMask?: File) => {
|
||||||
if (file === undefined) {
|
if (file === undefined) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
const useCustomMask = customMask !== undefined
|
||||||
// useLastLineGroup 的影响
|
// useLastLineGroup 的影响
|
||||||
// 1. 使用上一次的 mask
|
// 1. 使用上一次的 mask
|
||||||
// 2. 结果替换当前 render
|
// 2. 结果替换当前 render
|
||||||
console.log('runInpainting')
|
console.log('runInpainting')
|
||||||
|
|
||||||
let maskLineGroup = []
|
let maskLineGroup: LineGroup = []
|
||||||
if (useLastLineGroup === true) {
|
if (useLastLineGroup === true) {
|
||||||
if (lastLineGroup.length === 0) {
|
if (lastLineGroup.length === 0) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
maskLineGroup = lastLineGroup
|
maskLineGroup = lastLineGroup
|
||||||
} else {
|
} else if (!useCustomMask) {
|
||||||
if (!hadDrawSomething()) {
|
if (!hadDrawSomething()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -256,23 +261,23 @@ export default function Editor() {
|
|||||||
|
|
||||||
const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
|
const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
|
||||||
|
|
||||||
|
console.log({ useCustomMask })
|
||||||
try {
|
try {
|
||||||
const res = await inpaint(
|
const res = await inpaint(
|
||||||
targetFile,
|
targetFile,
|
||||||
maskCanvas.toDataURL(),
|
|
||||||
settings,
|
settings,
|
||||||
croperRect,
|
croperRect,
|
||||||
prompt,
|
promptVal,
|
||||||
negativePromptVal,
|
negativePromptVal,
|
||||||
sizeLimit.toString(),
|
sizeLimit.toString(),
|
||||||
sdSeed
|
sdSeed,
|
||||||
|
useCustomMask ? undefined : maskCanvas.toDataURL(),
|
||||||
|
useCustomMask ? customMask : undefined
|
||||||
)
|
)
|
||||||
if (!res) {
|
if (!res) {
|
||||||
throw new Error('empty response')
|
throw new Error('Something went wrong on server side.')
|
||||||
}
|
}
|
||||||
const { blob, seed } = res
|
const { blob, seed } = res
|
||||||
console.log(seed)
|
|
||||||
console.log(settings.sdSeedFixed)
|
|
||||||
if (seed && !settings.sdSeedFixed) {
|
if (seed && !settings.sdSeedFixed) {
|
||||||
setSeed(parseInt(seed, 10))
|
setSeed(parseInt(seed, 10))
|
||||||
}
|
}
|
||||||
@ -324,9 +329,9 @@ export default function Editor() {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
emitter.on(EVENT_PROMPT, () => {
|
emitter.on(EVENT_PROMPT, () => {
|
||||||
if (hadDrawSomething()) {
|
if (hadDrawSomething()) {
|
||||||
runInpainting(promptVal)
|
runInpainting()
|
||||||
} else if (lastLineGroup.length !== 0) {
|
} else if (lastLineGroup.length !== 0) {
|
||||||
runInpainting(promptVal, true)
|
runInpainting(true)
|
||||||
} else {
|
} else {
|
||||||
setToastState({
|
setToastState({
|
||||||
open: true,
|
open: true,
|
||||||
@ -336,10 +341,21 @@ export default function Editor() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
emitter.off(EVENT_PROMPT)
|
emitter.off(EVENT_PROMPT)
|
||||||
}
|
}
|
||||||
}, [hadDrawSomething, runInpainting, prompt])
|
}, [hadDrawSomething, runInpainting, promptVal])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
emitter.on(EVENT_CUSTOM_MASK, (data: any) => {
|
||||||
|
runInpainting(false, data.mask)
|
||||||
|
})
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
emitter.off(EVENT_CUSTOM_MASK)
|
||||||
|
}
|
||||||
|
}, [runInpainting])
|
||||||
|
|
||||||
const hadRunInpainting = () => {
|
const hadRunInpainting = () => {
|
||||||
return renders.length !== 0
|
return renders.length !== 0
|
||||||
|
@ -1,25 +1,33 @@
|
|||||||
import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline'
|
import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline'
|
||||||
import React, { useState } from 'react'
|
import React, { useState } from 'react'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import { fileState, isSDState } from '../../store/Atoms'
|
import { fileState, isInpaintingState, isSDState } from '../../store/Atoms'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
import Shortcuts from '../Shortcuts/Shortcuts'
|
import Shortcuts from '../Shortcuts/Shortcuts'
|
||||||
import useResolution from '../../hooks/useResolution'
|
|
||||||
import { ThemeChanger } from './ThemeChanger'
|
import { ThemeChanger } from './ThemeChanger'
|
||||||
import SettingIcon from '../Settings/SettingIcon'
|
import SettingIcon from '../Settings/SettingIcon'
|
||||||
import PromptInput from './PromptInput'
|
import PromptInput from './PromptInput'
|
||||||
import CoffeeIcon from '../CoffeeIcon/CoffeeIcon'
|
import CoffeeIcon from '../CoffeeIcon/CoffeeIcon'
|
||||||
|
import emitter, { EVENT_CUSTOM_MASK } from '../../event'
|
||||||
|
|
||||||
const Header = () => {
|
const Header = () => {
|
||||||
|
const isInpainting = useRecoilValue(isInpaintingState)
|
||||||
const [file, setFile] = useRecoilState(fileState)
|
const [file, setFile] = useRecoilState(fileState)
|
||||||
const resolution = useResolution()
|
|
||||||
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
|
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
|
||||||
|
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
|
||||||
const isSD = useRecoilValue(isSDState)
|
const isSD = useRecoilValue(isSDState)
|
||||||
|
|
||||||
const renderHeader = () => {
|
const renderHeader = () => {
|
||||||
return (
|
return (
|
||||||
<header>
|
<header>
|
||||||
<div>
|
<div
|
||||||
|
style={{
|
||||||
|
display: 'flex',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: 8,
|
||||||
|
}}
|
||||||
|
>
|
||||||
<label htmlFor={uploadElemId}>
|
<label htmlFor={uploadElemId}>
|
||||||
<Button icon={<UploadIcon />} style={{ border: 0 }}>
|
<Button icon={<UploadIcon />} style={{ border: 0 }}>
|
||||||
<input
|
<input
|
||||||
@ -35,7 +43,35 @@ const Header = () => {
|
|||||||
}}
|
}}
|
||||||
accept="image/png, image/jpeg"
|
accept="image/png, image/jpeg"
|
||||||
/>
|
/>
|
||||||
{resolution === 'desktop' ? 'Upload New' : undefined}
|
Image
|
||||||
|
</Button>
|
||||||
|
</label>
|
||||||
|
|
||||||
|
<label
|
||||||
|
htmlFor={maskUploadElemId}
|
||||||
|
style={{ visibility: file ? 'visible' : 'hidden' }}
|
||||||
|
>
|
||||||
|
<Button style={{ border: 0 }} disabled={isInpainting}>
|
||||||
|
<input
|
||||||
|
style={{ display: 'none' }}
|
||||||
|
id={maskUploadElemId}
|
||||||
|
name={maskUploadElemId}
|
||||||
|
type="file"
|
||||||
|
onClick={e => {
|
||||||
|
const element = e.target as HTMLInputElement
|
||||||
|
element.value = ''
|
||||||
|
}}
|
||||||
|
onChange={ev => {
|
||||||
|
const newFile = ev.currentTarget.files?.[0]
|
||||||
|
if (newFile) {
|
||||||
|
// TODO: check mask size
|
||||||
|
console.info('Send custom mask')
|
||||||
|
emitter.emit(EVENT_CUSTOM_MASK, { mask: newFile })
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
accept="image/png, image/jpeg"
|
||||||
|
/>
|
||||||
|
Mask
|
||||||
</Button>
|
</Button>
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import mitt from 'mitt'
|
import mitt from 'mitt'
|
||||||
|
|
||||||
export const EVENT_PROMPT = 'prompt'
|
export const EVENT_PROMPT = 'prompt'
|
||||||
|
export const EVENT_CUSTOM_MASK = 'custom_mask'
|
||||||
|
export interface CustomMaskEventData {
|
||||||
|
mask: File
|
||||||
|
}
|
||||||
|
|
||||||
const emitter = mitt()
|
const emitter = mitt()
|
||||||
|
|
||||||
|
@ -94,6 +94,10 @@ def process():
|
|||||||
origin_image_bytes = input["image"].read()
|
origin_image_bytes = input["image"].read()
|
||||||
|
|
||||||
image, alpha_channel = load_img(origin_image_bytes)
|
image, alpha_channel = load_img(origin_image_bytes)
|
||||||
|
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||||
|
if image.shape[:2] != mask.shape[:2]:
|
||||||
|
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
|
||||||
|
|
||||||
original_shape = image.shape
|
original_shape = image.shape
|
||||||
interpolation = cv2.INTER_CUBIC
|
interpolation = cv2.INTER_CUBIC
|
||||||
|
|
||||||
@ -136,14 +140,18 @@ def process():
|
|||||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||||
logger.info(f"Resized image shape: {image.shape}")
|
logger.info(f"Resized image shape: {image.shape}")
|
||||||
|
|
||||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
|
||||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
res_np_img = model(image, mask, config)
|
try:
|
||||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
res_np_img = model(image, mask, config)
|
||||||
|
except RuntimeError as e:
|
||||||
torch.cuda.empty_cache()
|
# NOTE: the string may change?
|
||||||
|
if "CUDA out of memory. " in str(e):
|
||||||
|
return "CUDA out of memory", 500
|
||||||
|
finally:
|
||||||
|
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if alpha_channel is not None:
|
if alpha_channel is not None:
|
||||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
||||||
|
Loading…
Reference in New Issue
Block a user