add custom mask upload, WIP, need more test

better handle server error
This commit is contained in:
Qing 2022-11-13 23:31:11 +08:00
parent 0666a32947
commit eec41734c3
5 changed files with 99 additions and 29 deletions

View File

@ -5,19 +5,23 @@ export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
export default async function inpaint( export default async function inpaint(
imageFile: File, imageFile: File,
maskBase64: string,
settings: Settings, settings: Settings,
croperRect: Rect, croperRect: Rect,
prompt?: string, prompt?: string,
negativePrompt?: string, negativePrompt?: string,
sizeLimit?: string, sizeLimit?: string,
seed?: number seed?: number,
maskBase64?: string,
customMask?: File
) { ) {
// 1080, 2000, Original // 1080, 2000, Original
const fd = new FormData() const fd = new FormData()
fd.append('image', imageFile) fd.append('image', imageFile)
const mask = dataURItoBlob(maskBase64) if (maskBase64 !== undefined) {
fd.append('mask', mask) fd.append('mask', dataURItoBlob(maskBase64))
} else if (customMask !== undefined) {
fd.append('mask', customMask)
}
const hdSettings = settings.hdSettings[settings.model] const hdSettings = settings.hdSettings[settings.model]
fd.append('ldmSteps', settings.ldmSteps.toString()) fd.append('ldmSteps', settings.ldmSteps.toString())
@ -70,8 +74,10 @@ export default async function inpaint(
const newSeed = res.headers.get('x-seed') const newSeed = res.headers.get('x-seed')
return { blob: URL.createObjectURL(blob), seed: newSeed } return { blob: URL.createObjectURL(blob), seed: newSeed }
} }
} catch { const errMsg = await res.text()
throw new Error('Something went wrong on server side.') throw new Error(errMsg)
} catch (error) {
throw new Error(`Something went wrong: ${error}`)
} }
} }

View File

@ -45,7 +45,11 @@ import {
} from '../../store/Atoms' } from '../../store/Atoms'
import useHotKey from '../../hooks/useHotkey' import useHotKey from '../../hooks/useHotkey'
import Croper from '../Croper/Croper' import Croper from '../Croper/Croper'
import emitter, { EVENT_PROMPT } from '../../event' import emitter, {
EVENT_PROMPT,
EVENT_CUSTOM_MASK,
CustomMaskEventData,
} from '../../event'
import FileSelect from '../FileSelect/FileSelect' import FileSelect from '../FileSelect/FileSelect'
const TOOLBAR_SIZE = 200 const TOOLBAR_SIZE = 200
@ -195,22 +199,23 @@ export default function Editor() {
) )
const runInpainting = useCallback( const runInpainting = useCallback(
async (prompt?: string, useLastLineGroup?: boolean) => { async (useLastLineGroup?: boolean, customMask?: File) => {
if (file === undefined) { if (file === undefined) {
return return
} }
const useCustomMask = customMask !== undefined
// useLastLineGroup 的影响 // useLastLineGroup 的影响
// 1. 使用上一次的 mask // 1. 使用上一次的 mask
// 2. 结果替换当前 render // 2. 结果替换当前 render
console.log('runInpainting') console.log('runInpainting')
let maskLineGroup = [] let maskLineGroup: LineGroup = []
if (useLastLineGroup === true) { if (useLastLineGroup === true) {
if (lastLineGroup.length === 0) { if (lastLineGroup.length === 0) {
return return
} }
maskLineGroup = lastLineGroup maskLineGroup = lastLineGroup
} else { } else if (!useCustomMask) {
if (!hadDrawSomething()) { if (!hadDrawSomething()) {
return return
} }
@ -256,23 +261,23 @@ export default function Editor() {
const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1 const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
console.log({ useCustomMask })
try { try {
const res = await inpaint( const res = await inpaint(
targetFile, targetFile,
maskCanvas.toDataURL(),
settings, settings,
croperRect, croperRect,
prompt, promptVal,
negativePromptVal, negativePromptVal,
sizeLimit.toString(), sizeLimit.toString(),
sdSeed sdSeed,
useCustomMask ? undefined : maskCanvas.toDataURL(),
useCustomMask ? customMask : undefined
) )
if (!res) { if (!res) {
throw new Error('empty response') throw new Error('Something went wrong on server side.')
} }
const { blob, seed } = res const { blob, seed } = res
console.log(seed)
console.log(settings.sdSeedFixed)
if (seed && !settings.sdSeedFixed) { if (seed && !settings.sdSeedFixed) {
setSeed(parseInt(seed, 10)) setSeed(parseInt(seed, 10))
} }
@ -324,9 +329,9 @@ export default function Editor() {
useEffect(() => { useEffect(() => {
emitter.on(EVENT_PROMPT, () => { emitter.on(EVENT_PROMPT, () => {
if (hadDrawSomething()) { if (hadDrawSomething()) {
runInpainting(promptVal) runInpainting()
} else if (lastLineGroup.length !== 0) { } else if (lastLineGroup.length !== 0) {
runInpainting(promptVal, true) runInpainting(true)
} else { } else {
setToastState({ setToastState({
open: true, open: true,
@ -336,10 +341,21 @@ export default function Editor() {
}) })
} }
}) })
return () => { return () => {
emitter.off(EVENT_PROMPT) emitter.off(EVENT_PROMPT)
} }
}, [hadDrawSomething, runInpainting, prompt]) }, [hadDrawSomething, runInpainting, promptVal])
useEffect(() => {
emitter.on(EVENT_CUSTOM_MASK, (data: any) => {
runInpainting(false, data.mask)
})
return () => {
emitter.off(EVENT_CUSTOM_MASK)
}
}, [runInpainting])
const hadRunInpainting = () => { const hadRunInpainting = () => {
return renders.length !== 0 return renders.length !== 0

View File

@ -1,25 +1,33 @@
import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline' import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline'
import React, { useState } from 'react' import React, { useState } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState, useRecoilValue } from 'recoil'
import { fileState, isSDState } from '../../store/Atoms' import { fileState, isInpaintingState, isSDState } from '../../store/Atoms'
import Button from '../shared/Button' import Button from '../shared/Button'
import Shortcuts from '../Shortcuts/Shortcuts' import Shortcuts from '../Shortcuts/Shortcuts'
import useResolution from '../../hooks/useResolution'
import { ThemeChanger } from './ThemeChanger' import { ThemeChanger } from './ThemeChanger'
import SettingIcon from '../Settings/SettingIcon' import SettingIcon from '../Settings/SettingIcon'
import PromptInput from './PromptInput' import PromptInput from './PromptInput'
import CoffeeIcon from '../CoffeeIcon/CoffeeIcon' import CoffeeIcon from '../CoffeeIcon/CoffeeIcon'
import emitter, { EVENT_CUSTOM_MASK } from '../../event'
const Header = () => { const Header = () => {
const isInpainting = useRecoilValue(isInpaintingState)
const [file, setFile] = useRecoilState(fileState) const [file, setFile] = useRecoilState(fileState)
const resolution = useResolution()
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`) const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
const isSD = useRecoilValue(isSDState) const isSD = useRecoilValue(isSDState)
const renderHeader = () => { const renderHeader = () => {
return ( return (
<header> <header>
<div> <div
style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
gap: 8,
}}
>
<label htmlFor={uploadElemId}> <label htmlFor={uploadElemId}>
<Button icon={<UploadIcon />} style={{ border: 0 }}> <Button icon={<UploadIcon />} style={{ border: 0 }}>
<input <input
@ -35,7 +43,35 @@ const Header = () => {
}} }}
accept="image/png, image/jpeg" accept="image/png, image/jpeg"
/> />
{resolution === 'desktop' ? 'Upload New' : undefined} Image
</Button>
</label>
<label
htmlFor={maskUploadElemId}
style={{ visibility: file ? 'visible' : 'hidden' }}
>
<Button style={{ border: 0 }} disabled={isInpainting}>
<input
style={{ display: 'none' }}
id={maskUploadElemId}
name={maskUploadElemId}
type="file"
onClick={e => {
const element = e.target as HTMLInputElement
element.value = ''
}}
onChange={ev => {
const newFile = ev.currentTarget.files?.[0]
if (newFile) {
// TODO: check mask size
console.info('Send custom mask')
emitter.emit(EVENT_CUSTOM_MASK, { mask: newFile })
}
}}
accept="image/png, image/jpeg"
/>
Mask
</Button> </Button>
</label> </label>
</div> </div>

View File

@ -1,6 +1,10 @@
import mitt from 'mitt' import mitt from 'mitt'
export const EVENT_PROMPT = 'prompt' export const EVENT_PROMPT = 'prompt'
export const EVENT_CUSTOM_MASK = 'custom_mask'
export interface CustomMaskEventData {
mask: File
}
const emitter = mitt() const emitter = mitt()

View File

@ -94,6 +94,10 @@ def process():
origin_image_bytes = input["image"].read() origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes) image, alpha_channel = load_img(origin_image_bytes)
mask, _ = load_img(input["mask"].read(), gray=True)
if image.shape[:2] != mask.shape[:2]:
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
original_shape = image.shape original_shape = image.shape
interpolation = cv2.INTER_CUBIC interpolation = cv2.INTER_CUBIC
@ -136,14 +140,18 @@ def process():
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
logger.info(f"Resized image shape: {image.shape}") logger.info(f"Resized image shape: {image.shape}")
mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
start = time.time() start = time.time()
res_np_img = model(image, mask, config) try:
logger.info(f"process time: {(time.time() - start) * 1000}ms") res_np_img = model(image, mask, config)
except RuntimeError as e:
torch.cuda.empty_cache() # NOTE: the string may change?
if "CUDA out of memory. " in str(e):
return "CUDA out of memory", 500
finally:
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
if alpha_channel is not None: if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]: if alpha_channel.shape[:2] != res_np_img.shape[:2]: