add custom mask upload, WIP, need more test
better handle server error
This commit is contained in:
parent
0666a32947
commit
eec41734c3
@ -5,19 +5,23 @@ export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||
|
||||
export default async function inpaint(
|
||||
imageFile: File,
|
||||
maskBase64: string,
|
||||
settings: Settings,
|
||||
croperRect: Rect,
|
||||
prompt?: string,
|
||||
negativePrompt?: string,
|
||||
sizeLimit?: string,
|
||||
seed?: number
|
||||
seed?: number,
|
||||
maskBase64?: string,
|
||||
customMask?: File
|
||||
) {
|
||||
// 1080, 2000, Original
|
||||
const fd = new FormData()
|
||||
fd.append('image', imageFile)
|
||||
const mask = dataURItoBlob(maskBase64)
|
||||
fd.append('mask', mask)
|
||||
if (maskBase64 !== undefined) {
|
||||
fd.append('mask', dataURItoBlob(maskBase64))
|
||||
} else if (customMask !== undefined) {
|
||||
fd.append('mask', customMask)
|
||||
}
|
||||
|
||||
const hdSettings = settings.hdSettings[settings.model]
|
||||
fd.append('ldmSteps', settings.ldmSteps.toString())
|
||||
@ -70,8 +74,10 @@ export default async function inpaint(
|
||||
const newSeed = res.headers.get('x-seed')
|
||||
return { blob: URL.createObjectURL(blob), seed: newSeed }
|
||||
}
|
||||
} catch {
|
||||
throw new Error('Something went wrong on server side.')
|
||||
const errMsg = await res.text()
|
||||
throw new Error(errMsg)
|
||||
} catch (error) {
|
||||
throw new Error(`Something went wrong: ${error}`)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,11 @@ import {
|
||||
} from '../../store/Atoms'
|
||||
import useHotKey from '../../hooks/useHotkey'
|
||||
import Croper from '../Croper/Croper'
|
||||
import emitter, { EVENT_PROMPT } from '../../event'
|
||||
import emitter, {
|
||||
EVENT_PROMPT,
|
||||
EVENT_CUSTOM_MASK,
|
||||
CustomMaskEventData,
|
||||
} from '../../event'
|
||||
import FileSelect from '../FileSelect/FileSelect'
|
||||
|
||||
const TOOLBAR_SIZE = 200
|
||||
@ -195,22 +199,23 @@ export default function Editor() {
|
||||
)
|
||||
|
||||
const runInpainting = useCallback(
|
||||
async (prompt?: string, useLastLineGroup?: boolean) => {
|
||||
async (useLastLineGroup?: boolean, customMask?: File) => {
|
||||
if (file === undefined) {
|
||||
return
|
||||
}
|
||||
const useCustomMask = customMask !== undefined
|
||||
// useLastLineGroup 的影响
|
||||
// 1. 使用上一次的 mask
|
||||
// 2. 结果替换当前 render
|
||||
console.log('runInpainting')
|
||||
|
||||
let maskLineGroup = []
|
||||
let maskLineGroup: LineGroup = []
|
||||
if (useLastLineGroup === true) {
|
||||
if (lastLineGroup.length === 0) {
|
||||
return
|
||||
}
|
||||
maskLineGroup = lastLineGroup
|
||||
} else {
|
||||
} else if (!useCustomMask) {
|
||||
if (!hadDrawSomething()) {
|
||||
return
|
||||
}
|
||||
@ -256,23 +261,23 @@ export default function Editor() {
|
||||
|
||||
const sdSeed = settings.sdSeedFixed ? settings.sdSeed : -1
|
||||
|
||||
console.log({ useCustomMask })
|
||||
try {
|
||||
const res = await inpaint(
|
||||
targetFile,
|
||||
maskCanvas.toDataURL(),
|
||||
settings,
|
||||
croperRect,
|
||||
prompt,
|
||||
promptVal,
|
||||
negativePromptVal,
|
||||
sizeLimit.toString(),
|
||||
sdSeed
|
||||
sdSeed,
|
||||
useCustomMask ? undefined : maskCanvas.toDataURL(),
|
||||
useCustomMask ? customMask : undefined
|
||||
)
|
||||
if (!res) {
|
||||
throw new Error('empty response')
|
||||
throw new Error('Something went wrong on server side.')
|
||||
}
|
||||
const { blob, seed } = res
|
||||
console.log(seed)
|
||||
console.log(settings.sdSeedFixed)
|
||||
if (seed && !settings.sdSeedFixed) {
|
||||
setSeed(parseInt(seed, 10))
|
||||
}
|
||||
@ -324,9 +329,9 @@ export default function Editor() {
|
||||
useEffect(() => {
|
||||
emitter.on(EVENT_PROMPT, () => {
|
||||
if (hadDrawSomething()) {
|
||||
runInpainting(promptVal)
|
||||
runInpainting()
|
||||
} else if (lastLineGroup.length !== 0) {
|
||||
runInpainting(promptVal, true)
|
||||
runInpainting(true)
|
||||
} else {
|
||||
setToastState({
|
||||
open: true,
|
||||
@ -336,10 +341,21 @@ export default function Editor() {
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return () => {
|
||||
emitter.off(EVENT_PROMPT)
|
||||
}
|
||||
}, [hadDrawSomething, runInpainting, prompt])
|
||||
}, [hadDrawSomething, runInpainting, promptVal])
|
||||
|
||||
useEffect(() => {
|
||||
emitter.on(EVENT_CUSTOM_MASK, (data: any) => {
|
||||
runInpainting(false, data.mask)
|
||||
})
|
||||
|
||||
return () => {
|
||||
emitter.off(EVENT_CUSTOM_MASK)
|
||||
}
|
||||
}, [runInpainting])
|
||||
|
||||
const hadRunInpainting = () => {
|
||||
return renders.length !== 0
|
||||
|
@ -1,25 +1,33 @@
|
||||
import { ArrowLeftIcon, UploadIcon } from '@heroicons/react/outline'
|
||||
import React, { useState } from 'react'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import { fileState, isSDState } from '../../store/Atoms'
|
||||
import { fileState, isInpaintingState, isSDState } from '../../store/Atoms'
|
||||
import Button from '../shared/Button'
|
||||
import Shortcuts from '../Shortcuts/Shortcuts'
|
||||
import useResolution from '../../hooks/useResolution'
|
||||
import { ThemeChanger } from './ThemeChanger'
|
||||
import SettingIcon from '../Settings/SettingIcon'
|
||||
import PromptInput from './PromptInput'
|
||||
import CoffeeIcon from '../CoffeeIcon/CoffeeIcon'
|
||||
import emitter, { EVENT_CUSTOM_MASK } from '../../event'
|
||||
|
||||
const Header = () => {
|
||||
const isInpainting = useRecoilValue(isInpaintingState)
|
||||
const [file, setFile] = useRecoilState(fileState)
|
||||
const resolution = useResolution()
|
||||
const [uploadElemId] = useState(`file-upload-${Math.random().toString()}`)
|
||||
const [maskUploadElemId] = useState(`mask-upload-${Math.random().toString()}`)
|
||||
const isSD = useRecoilValue(isSDState)
|
||||
|
||||
const renderHeader = () => {
|
||||
return (
|
||||
<header>
|
||||
<div>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
gap: 8,
|
||||
}}
|
||||
>
|
||||
<label htmlFor={uploadElemId}>
|
||||
<Button icon={<UploadIcon />} style={{ border: 0 }}>
|
||||
<input
|
||||
@ -35,7 +43,35 @@ const Header = () => {
|
||||
}}
|
||||
accept="image/png, image/jpeg"
|
||||
/>
|
||||
{resolution === 'desktop' ? 'Upload New' : undefined}
|
||||
Image
|
||||
</Button>
|
||||
</label>
|
||||
|
||||
<label
|
||||
htmlFor={maskUploadElemId}
|
||||
style={{ visibility: file ? 'visible' : 'hidden' }}
|
||||
>
|
||||
<Button style={{ border: 0 }} disabled={isInpainting}>
|
||||
<input
|
||||
style={{ display: 'none' }}
|
||||
id={maskUploadElemId}
|
||||
name={maskUploadElemId}
|
||||
type="file"
|
||||
onClick={e => {
|
||||
const element = e.target as HTMLInputElement
|
||||
element.value = ''
|
||||
}}
|
||||
onChange={ev => {
|
||||
const newFile = ev.currentTarget.files?.[0]
|
||||
if (newFile) {
|
||||
// TODO: check mask size
|
||||
console.info('Send custom mask')
|
||||
emitter.emit(EVENT_CUSTOM_MASK, { mask: newFile })
|
||||
}
|
||||
}}
|
||||
accept="image/png, image/jpeg"
|
||||
/>
|
||||
Mask
|
||||
</Button>
|
||||
</label>
|
||||
</div>
|
||||
|
@ -1,6 +1,10 @@
|
||||
import mitt from 'mitt'
|
||||
|
||||
export const EVENT_PROMPT = 'prompt'
|
||||
export const EVENT_CUSTOM_MASK = 'custom_mask'
|
||||
export interface CustomMaskEventData {
|
||||
mask: File
|
||||
}
|
||||
|
||||
const emitter = mitt()
|
||||
|
||||
|
@ -94,6 +94,10 @@ def process():
|
||||
origin_image_bytes = input["image"].read()
|
||||
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
if image.shape[:2] != mask.shape[:2]:
|
||||
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
|
||||
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
|
||||
@ -136,13 +140,17 @@ def process():
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
logger.info(f"Resized image shape: {image.shape}")
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
|
||||
start = time.time()
|
||||
try:
|
||||
res_np_img = model(image, mask, config)
|
||||
except RuntimeError as e:
|
||||
# NOTE: the string may change?
|
||||
if "CUDA out of memory. " in str(e):
|
||||
return "CUDA out of memory", 500
|
||||
finally:
|
||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if alpha_channel is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user