lots of updates
This commit is contained in:
parent
2e8e52f7a5
commit
a22536becc
@ -1,8 +1,7 @@
|
|||||||
import React, { useCallback, useEffect, useMemo, useState } from 'react'
|
import React, { useCallback, useEffect, useMemo } from 'react'
|
||||||
import { useRecoilState } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import { nanoid } from 'nanoid'
|
import { nanoid } from 'nanoid'
|
||||||
import useInputImage from './hooks/useInputImage'
|
import useInputImage from './hooks/useInputImage'
|
||||||
import LandingPage from './components/LandingPage/LandingPage'
|
|
||||||
import { themeState } from './components/Header/ThemeChanger'
|
import { themeState } from './components/Header/ThemeChanger'
|
||||||
import Workspace from './components/Workspace'
|
import Workspace from './components/Workspace'
|
||||||
import {
|
import {
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { Rect, Settings } from '../store/Atoms'
|
import { Rect, Settings } from '../store/Atoms'
|
||||||
import { dataURItoBlob } from '../utils'
|
import { dataURItoBlob, srcToFile } from '../utils'
|
||||||
|
|
||||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||||
|
|
||||||
@ -57,6 +57,7 @@ export default async function inpaint(
|
|||||||
fd.append('sdSampler', settings.sdSampler.toString())
|
fd.append('sdSampler', settings.sdSampler.toString())
|
||||||
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
||||||
fd.append('sdMatchHistograms', settings.sdMatchHistograms ? 'true' : 'false')
|
fd.append('sdMatchHistograms', settings.sdMatchHistograms ? 'true' : 'false')
|
||||||
|
fd.append('sdScale', (settings.sdScale / 100).toString())
|
||||||
|
|
||||||
fd.append('cv2Radius', settings.cv2Radius.toString())
|
fd.append('cv2Radius', settings.cv2Radius.toString())
|
||||||
fd.append('cv2Flag', settings.cv2Flag.toString())
|
fd.append('cv2Flag', settings.cv2Flag.toString())
|
||||||
@ -198,3 +199,27 @@ export async function getMedias() {
|
|||||||
const errMsg = await res.text()
|
const errMsg = await res.text()
|
||||||
throw new Error(errMsg)
|
throw new Error(errMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function downloadToOutput(
|
||||||
|
image: HTMLImageElement,
|
||||||
|
filename: string,
|
||||||
|
mimeType: string
|
||||||
|
) {
|
||||||
|
const file = await srcToFile(image.src, filename, mimeType)
|
||||||
|
const fd = new FormData()
|
||||||
|
fd.append('image', file)
|
||||||
|
fd.append('filename', filename)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${API_ENDPOINT}/save_image`, {
|
||||||
|
method: 'POST',
|
||||||
|
body: fd,
|
||||||
|
})
|
||||||
|
if (!res.ok) {
|
||||||
|
const errMsg = await res.text()
|
||||||
|
throw new Error(errMsg)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Something went wrong: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -18,7 +18,10 @@ import {
|
|||||||
} from 'react-zoom-pan-pinch'
|
} from 'react-zoom-pan-pinch'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
||||||
import inpaint, { postInteractiveSeg } from '../../adapters/inpainting'
|
import inpaint, {
|
||||||
|
downloadToOutput,
|
||||||
|
postInteractiveSeg,
|
||||||
|
} from '../../adapters/inpainting'
|
||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
import Slider from './Slider'
|
import Slider from './Slider'
|
||||||
import SizeSelector from './SizeSelector'
|
import SizeSelector from './SizeSelector'
|
||||||
@ -34,7 +37,10 @@ import {
|
|||||||
} from '../../utils'
|
} from '../../utils'
|
||||||
import {
|
import {
|
||||||
croperState,
|
croperState,
|
||||||
|
enableFileManagerState,
|
||||||
fileState,
|
fileState,
|
||||||
|
imageHeightState,
|
||||||
|
imageWidthState,
|
||||||
interactiveSegClicksState,
|
interactiveSegClicksState,
|
||||||
isInpaintingState,
|
isInpaintingState,
|
||||||
isInteractiveSegRunningState,
|
isInteractiveSegRunningState,
|
||||||
@ -173,6 +179,10 @@ export default function Editor() {
|
|||||||
const [redoRenders, setRedoRenders] = useState<HTMLImageElement[]>([])
|
const [redoRenders, setRedoRenders] = useState<HTMLImageElement[]>([])
|
||||||
const [redoCurLines, setRedoCurLines] = useState<Line[]>([])
|
const [redoCurLines, setRedoCurLines] = useState<Line[]>([])
|
||||||
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
|
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
|
||||||
|
const enableFileManager = useRecoilValue(enableFileManagerState)
|
||||||
|
|
||||||
|
const [imageWidth, setImageWidth] = useRecoilState(imageWidthState)
|
||||||
|
const [imageHeight, setImageHeight] = useRecoilState(imageHeightState)
|
||||||
|
|
||||||
const draw = useCallback(
|
const draw = useCallback(
|
||||||
(render: HTMLImageElement, lineGroup: LineGroup) => {
|
(render: HTMLImageElement, lineGroup: LineGroup) => {
|
||||||
@ -524,6 +534,9 @@ export default function Editor() {
|
|||||||
const rW = windowSize.width / original.naturalWidth
|
const rW = windowSize.width / original.naturalWidth
|
||||||
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
|
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
|
||||||
|
|
||||||
|
setImageWidth(original.naturalWidth)
|
||||||
|
setImageHeight(original.naturalHeight)
|
||||||
|
|
||||||
let s = 1.0
|
let s = 1.0
|
||||||
if (rW < 1 || rH < 1) {
|
if (rW < 1 || rH < 1) {
|
||||||
s = Math.min(rW, rH)
|
s = Math.min(rW, rH)
|
||||||
@ -1054,6 +1067,27 @@ export default function Editor() {
|
|||||||
if (file === undefined) {
|
if (file === undefined) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if (enableFileManager && renders.length > 0) {
|
||||||
|
try {
|
||||||
|
downloadToOutput(renders[renders.length - 1], file.name, file.type)
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: `Save image success`,
|
||||||
|
state: 'success',
|
||||||
|
duration: 2000,
|
||||||
|
})
|
||||||
|
} catch (e: any) {
|
||||||
|
setToastState({
|
||||||
|
open: true,
|
||||||
|
desc: e.message ? e.message : e.toString(),
|
||||||
|
state: 'error',
|
||||||
|
duration: 2000,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: download to output directory
|
||||||
const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1')
|
const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1')
|
||||||
const curRender = renders[renders.length - 1]
|
const curRender = renders[renders.length - 1]
|
||||||
downloadImage(curRender.currentSrc, name)
|
downloadImage(curRender.currentSrc, name)
|
||||||
|
@ -6,14 +6,23 @@ type SliderProps = {
|
|||||||
min?: number
|
min?: number
|
||||||
max?: number
|
max?: number
|
||||||
onChange: (value: number) => void
|
onChange: (value: number) => void
|
||||||
onClick: () => void
|
onClick?: () => void
|
||||||
|
width?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function Slider(props: SliderProps) {
|
export default function Slider(props: SliderProps) {
|
||||||
const { value, onChange, onClick, label, min, max } = props
|
const { value, onChange, onClick, label, min, max, width } = props
|
||||||
|
const styles: any = {}
|
||||||
|
if (width !== undefined) {
|
||||||
|
styles.width = width
|
||||||
|
}
|
||||||
|
|
||||||
const step = ((max || 100) - (min || 0)) / 100
|
const step = ((max || 100) - (min || 0)) / 100
|
||||||
|
|
||||||
|
const onMouseUp = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||||
|
e.currentTarget?.blur()
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="editor-brush-slider">
|
<div className="editor-brush-slider">
|
||||||
<span>{label}</span>
|
<span>{label}</span>
|
||||||
@ -29,6 +38,8 @@ export default function Slider(props: SliderProps) {
|
|||||||
onChange(parseInt(ev.currentTarget.value, 10))
|
onChange(parseInt(ev.currentTarget.value, 10))
|
||||||
}}
|
}}
|
||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
|
style={styles}
|
||||||
|
onMouseUp={onMouseUp}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,7 @@ import React, {
|
|||||||
useState,
|
useState,
|
||||||
useCallback,
|
useCallback,
|
||||||
} from 'react'
|
} from 'react'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState } from 'recoil'
|
||||||
import PhotoAlbum, { RenderPhoto } from 'react-photo-album'
|
import PhotoAlbum, { RenderPhoto } from 'react-photo-album'
|
||||||
import * as ScrollArea from '@radix-ui/react-scroll-area'
|
import * as ScrollArea from '@radix-ui/react-scroll-area'
|
||||||
import Modal from '../shared/Modal'
|
import Modal from '../shared/Modal'
|
||||||
@ -127,7 +127,7 @@ export default function FileManager(props: Props) {
|
|||||||
ref={onRefChange}
|
ref={onRefChange}
|
||||||
>
|
>
|
||||||
<PhotoAlbum
|
<PhotoAlbum
|
||||||
layout="columns"
|
layout="masonry"
|
||||||
photos={photos}
|
photos={photos}
|
||||||
renderPhoto={renderPhoto}
|
renderPhoto={renderPhoto}
|
||||||
spacing={8}
|
spacing={8}
|
||||||
|
@ -21,7 +21,7 @@ function SettingBlock(props: SettingBlockProps) {
|
|||||||
<div className={`setting-block ${className}`}>
|
<div className={`setting-block ${className}`}>
|
||||||
<div className={contentClass}>
|
<div className={contentClass}>
|
||||||
<div className="setting-block-content-title">
|
<div className="setting-block-content-title">
|
||||||
<div style={{ display: 'flex', alignItems: 'center', gap: '12px' }}>
|
<div style={{ display: 'flex', alignItems: 'center', gap: '8px' }}>
|
||||||
{desc ? (
|
{desc ? (
|
||||||
<Tooltip content={<div style={{ maxWidth: 400 }}>{desc}</div>}>
|
<Tooltip content={<div style={{ maxWidth: 400 }}>{desc}</div>}>
|
||||||
<span>{title}</span>
|
<span>{title}</span>
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import React, { FormEvent } from 'react'
|
import React from 'react'
|
||||||
|
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import { FolderOpenIcon } from '@heroicons/react/24/outline'
|
|
||||||
import * as Tabs from '@radix-ui/react-tabs'
|
|
||||||
import {
|
import {
|
||||||
isPaintByExampleState,
|
isPaintByExampleState,
|
||||||
isSDState,
|
isSDState,
|
||||||
@ -14,10 +11,6 @@ import HDSettingBlock from './HDSettingBlock'
|
|||||||
import ModelSettingBlock from './ModelSettingBlock'
|
import ModelSettingBlock from './ModelSettingBlock'
|
||||||
import DownloadMaskSettingBlock from './DownloadMaskSettingBlock'
|
import DownloadMaskSettingBlock from './DownloadMaskSettingBlock'
|
||||||
import useHotKey from '../../hooks/useHotkey'
|
import useHotKey from '../../hooks/useHotkey'
|
||||||
import SettingBlock from './SettingBlock'
|
|
||||||
import { Switch, SwitchThumb } from '../shared/Switch'
|
|
||||||
import Button from '../shared/Button'
|
|
||||||
import TextInput from '../shared/Input'
|
|
||||||
|
|
||||||
declare module 'react' {
|
declare module 'react' {
|
||||||
interface InputHTMLAttributes<T> extends HTMLAttributes<T> {
|
interface InputHTMLAttributes<T> extends HTMLAttributes<T> {
|
||||||
|
@ -0,0 +1,58 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
|
import { appState, croperState, settingState } from '../../store/Atoms'
|
||||||
|
import Slider from '../Editor/Slider'
|
||||||
|
import SettingBlock from '../Settings/SettingBlock'
|
||||||
|
|
||||||
|
const ImageResizeScale = () => {
|
||||||
|
const [setting, setSettingState] = useRecoilState(settingState)
|
||||||
|
const app = useRecoilValue(appState)
|
||||||
|
const croper = useRecoilValue(croperState)
|
||||||
|
|
||||||
|
const handleSliderChange = (value: number) => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, sdScale: value }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const scaledWidth = () => {
|
||||||
|
let width = app.imageWidth
|
||||||
|
if (setting.showCroper) {
|
||||||
|
width = croper.width
|
||||||
|
}
|
||||||
|
return Math.round((width * setting.sdScale) / 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
const scaledHeight = () => {
|
||||||
|
let height = app.imageHeight
|
||||||
|
if (setting.showCroper) {
|
||||||
|
height = croper.height
|
||||||
|
}
|
||||||
|
return Math.round((height * setting.sdScale) / 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<SettingBlock
|
||||||
|
className="sub-setting-block"
|
||||||
|
title="Resize"
|
||||||
|
titleSuffix={
|
||||||
|
<div
|
||||||
|
style={{ width: 86 }}
|
||||||
|
>{`(${scaledWidth()}x${scaledHeight()})`}</div>
|
||||||
|
}
|
||||||
|
desc="Resize the image before inpainting, the area outside the mask will not lose quality."
|
||||||
|
input={
|
||||||
|
<Slider
|
||||||
|
label=""
|
||||||
|
width={70}
|
||||||
|
min={50}
|
||||||
|
max={100}
|
||||||
|
value={setting.sdScale}
|
||||||
|
onChange={handleSliderChange}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ImageResizeScale
|
@ -14,6 +14,7 @@ import { Switch, SwitchThumb } from '../shared/Switch'
|
|||||||
import Button from '../shared/Button'
|
import Button from '../shared/Button'
|
||||||
import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event'
|
import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event'
|
||||||
import { useImage } from '../../utils'
|
import { useImage } from '../../utils'
|
||||||
|
import ImageResizeScale from './ImageResizeScale'
|
||||||
|
|
||||||
const INPUT_WIDTH = 30
|
const INPUT_WIDTH = 30
|
||||||
|
|
||||||
@ -68,22 +69,6 @@ const PESidePanel = () => {
|
|||||||
</PopoverPrimitive.Trigger>
|
</PopoverPrimitive.Trigger>
|
||||||
<PopoverPrimitive.Portal>
|
<PopoverPrimitive.Portal>
|
||||||
<PopoverPrimitive.Content className="side-panel-content">
|
<PopoverPrimitive.Content className="side-panel-content">
|
||||||
<SettingBlock
|
|
||||||
title="Croper"
|
|
||||||
input={
|
|
||||||
<Switch
|
|
||||||
checked={setting.showCroper}
|
|
||||||
onCheckedChange={value => {
|
|
||||||
setSettingState(old => {
|
|
||||||
return { ...old, showCroper: value }
|
|
||||||
})
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<SwitchThumb />
|
|
||||||
</Switch>
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<NumberInputSetting
|
<NumberInputSetting
|
||||||
title="Steps"
|
title="Steps"
|
||||||
width={INPUT_WIDTH}
|
width={INPUT_WIDTH}
|
||||||
@ -97,6 +82,8 @@ const PESidePanel = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<ImageResizeScale />
|
||||||
|
|
||||||
<NumberInputSetting
|
<NumberInputSetting
|
||||||
title="Guidance Scale"
|
title="Guidance Scale"
|
||||||
width={INPUT_WIDTH}
|
width={INPUT_WIDTH}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import React, { FormEvent, useState } from 'react'
|
import React, { FormEvent } from 'react'
|
||||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||||
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
||||||
import { useToggle } from 'react-use'
|
import { useToggle } from 'react-use'
|
||||||
@ -15,6 +15,7 @@ import Selector from '../shared/Selector'
|
|||||||
import { Switch, SwitchThumb } from '../shared/Switch'
|
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||||
import TextAreaInput from '../shared/Textarea'
|
import TextAreaInput from '../shared/Textarea'
|
||||||
import emitter, { EVENT_PROMPT } from '../../event'
|
import emitter, { EVENT_PROMPT } from '../../event'
|
||||||
|
import ImageResizeScale from './ImageResizeScale'
|
||||||
|
|
||||||
const INPUT_WIDTH = 30
|
const INPUT_WIDTH = 30
|
||||||
|
|
||||||
@ -71,6 +72,9 @@ const SidePanel = () => {
|
|||||||
</Switch>
|
</Switch>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<ImageResizeScale />
|
||||||
|
|
||||||
{/*
|
{/*
|
||||||
<NumberInputSetting
|
<NumberInputSetting
|
||||||
title="Num Samples"
|
title="Num Samples"
|
||||||
@ -98,21 +102,6 @@ const SidePanel = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{/* <NumberInputSetting
|
|
||||||
title="Strength"
|
|
||||||
width={INPUT_WIDTH}
|
|
||||||
allowFloat
|
|
||||||
value={`${setting.sdStrength}`}
|
|
||||||
desc="TODO"
|
|
||||||
onValue={value => {
|
|
||||||
const val = value.length === 0 ? 0 : parseFloat(value)
|
|
||||||
console.log(val)
|
|
||||||
setSettingState(old => {
|
|
||||||
return { ...old, sdStrength: val }
|
|
||||||
})
|
|
||||||
}}
|
|
||||||
/> */}
|
|
||||||
|
|
||||||
<NumberInputSetting
|
<NumberInputSetting
|
||||||
title="Guidance Scale"
|
title="Guidance Scale"
|
||||||
width={INPUT_WIDTH}
|
width={INPUT_WIDTH}
|
||||||
|
@ -35,6 +35,8 @@ export interface Rect {
|
|||||||
|
|
||||||
interface AppState {
|
interface AppState {
|
||||||
file: File | undefined
|
file: File | undefined
|
||||||
|
imageHeight: number
|
||||||
|
imageWidth: number
|
||||||
disableShortCuts: boolean
|
disableShortCuts: boolean
|
||||||
isInpainting: boolean
|
isInpainting: boolean
|
||||||
isDisableModelSwitch: boolean
|
isDisableModelSwitch: boolean
|
||||||
@ -49,6 +51,8 @@ export const appState = atom<AppState>({
|
|||||||
key: 'appState',
|
key: 'appState',
|
||||||
default: {
|
default: {
|
||||||
file: undefined,
|
file: undefined,
|
||||||
|
imageHeight: 0,
|
||||||
|
imageWidth: 0,
|
||||||
disableShortCuts: false,
|
disableShortCuts: false,
|
||||||
isInpainting: false,
|
isInpainting: false,
|
||||||
isDisableModelSwitch: false,
|
isDisableModelSwitch: false,
|
||||||
@ -82,6 +86,30 @@ export const isInpaintingState = selector({
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const imageHeightState = selector({
|
||||||
|
key: 'imageHeightState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const app = get(appState)
|
||||||
|
return app.imageHeight
|
||||||
|
},
|
||||||
|
set: ({ get, set }, newValue: any) => {
|
||||||
|
const app = get(appState)
|
||||||
|
set(appState, { ...app, imageHeight: newValue })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
export const imageWidthState = selector({
|
||||||
|
key: 'imageWidthState',
|
||||||
|
get: ({ get }) => {
|
||||||
|
const app = get(appState)
|
||||||
|
return app.imageWidth
|
||||||
|
},
|
||||||
|
set: ({ get, set }, newValue: any) => {
|
||||||
|
const app = get(appState)
|
||||||
|
set(appState, { ...app, imageWidth: newValue })
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
export const showFileManagerState = selector({
|
export const showFileManagerState = selector({
|
||||||
key: 'showFileManager',
|
key: 'showFileManager',
|
||||||
get: ({ get }) => {
|
get: ({ get }) => {
|
||||||
@ -121,6 +149,12 @@ export const fileState = selector({
|
|||||||
isInteractiveSeg: false,
|
isInteractiveSeg: false,
|
||||||
isInteractiveSegRunning: false,
|
isInteractiveSegRunning: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const setting = get(settingState)
|
||||||
|
set(settingState, {
|
||||||
|
...setting,
|
||||||
|
sdScale: 100,
|
||||||
|
})
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -282,6 +316,7 @@ export interface Settings {
|
|||||||
sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend
|
sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend
|
||||||
sdNumSamples: number
|
sdNumSamples: number
|
||||||
sdMatchHistograms: boolean
|
sdMatchHistograms: boolean
|
||||||
|
sdScale: number
|
||||||
|
|
||||||
// For OpenCV2
|
// For OpenCV2
|
||||||
cv2Radius: number
|
cv2Radius: number
|
||||||
@ -409,6 +444,7 @@ export const settingStateDefault: Settings = {
|
|||||||
sdSeedFixed: true,
|
sdSeedFixed: true,
|
||||||
sdNumSamples: 1,
|
sdNumSamples: 1,
|
||||||
sdMatchHistograms: false,
|
sdMatchHistograms: false,
|
||||||
|
sdScale: 100,
|
||||||
|
|
||||||
// CV2
|
// CV2
|
||||||
cv2Radius: 5,
|
cv2Radius: 5,
|
||||||
|
@ -1,11 +1,16 @@
|
|||||||
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
|
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from cachetools import TTLCache, cached
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image, ImageOps, PngImagePlugin
|
from PIL import Image, ImageOps, PngImagePlugin
|
||||||
|
|
||||||
LARGE_ENOUGH_NUMBER = 100
|
LARGE_ENOUGH_NUMBER = 100
|
||||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2)
|
||||||
from .storage_backends import FilesystemStorageBackend
|
from .storage_backends import FilesystemStorageBackend
|
||||||
from .utils import aspect_to_string, generate_filename, glob_img
|
from .utils import aspect_to_string, generate_filename, glob_img
|
||||||
|
|
||||||
@ -18,6 +23,7 @@ class FileManager:
|
|||||||
self._default_root_url = "/"
|
self._default_root_url = "/"
|
||||||
self._default_thumbnail_root_url = "/"
|
self._default_thumbnail_root_url = "/"
|
||||||
self._default_format = "JPEG"
|
self._default_format = "JPEG"
|
||||||
|
self.output_dir: Path = None
|
||||||
|
|
||||||
if app is not None:
|
if app is not None:
|
||||||
self.init_app(app)
|
self.init_app(app)
|
||||||
@ -41,6 +47,16 @@ class FileManager:
|
|||||||
app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url)
|
app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url)
|
||||||
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
|
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
|
||||||
|
|
||||||
|
def save_to_output_directory(self, image: np.ndarray, filename: str):
|
||||||
|
fp = Path(filename)
|
||||||
|
new_name = fp.stem + f"_{int(time.time())}" + fp.suffix
|
||||||
|
if image.shape[2] == 3:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||||
|
elif image.shape[2] == 4:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)
|
||||||
|
|
||||||
|
cv2.imwrite(str(self.output_dir / new_name), image)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_directory(self):
|
def root_directory(self):
|
||||||
path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
|
path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
|
||||||
@ -64,7 +80,7 @@ class FileManager:
|
|||||||
return self.app.config["THUMBNAIL_MEDIA_URL"]
|
return self.app.config["THUMBNAIL_MEDIA_URL"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@lru_cache()
|
@cached(cache=TTLCache(maxsize=1024, ttl=30))
|
||||||
def media_names(self):
|
def media_names(self):
|
||||||
names = sorted([it.name for it in glob_img(self.root_directory)])
|
names = sorted([it.name for it in glob_img(self.root_directory)])
|
||||||
res = []
|
res = []
|
||||||
|
@ -6,6 +6,9 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.helper import resize_max_size
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
from lama_cleaner.schema import Config
|
from lama_cleaner.schema import Config
|
||||||
|
|
||||||
@ -15,15 +18,20 @@ class PaintByExample(InpaintModel):
|
|||||||
min_size = 512
|
min_size = 512
|
||||||
|
|
||||||
def init_model(self, device: torch.device, **kwargs):
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
fp16 = not kwargs['no_half']
|
fp16 = not kwargs.get('no_half', False)
|
||||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
|
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
||||||
self.model = DiffusionPipeline.from_pretrained(
|
self.model = DiffusionPipeline.from_pretrained(
|
||||||
"Fantasy-Studio/Paint-by-Example",
|
"Fantasy-Studio/Paint-by-Example",
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
|
**model_kwargs
|
||||||
)
|
)
|
||||||
self.model.enable_attention_slicing()
|
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
|
self.model.enable_attention_slicing()
|
||||||
|
# TODO: gpu_id
|
||||||
|
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
||||||
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
@ -49,6 +57,25 @@ class PaintByExample(InpaintModel):
|
|||||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _scaled_pad_forward(self, image, mask, config: Config):
|
||||||
|
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||||
|
origin_size = image.shape[:2]
|
||||||
|
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||||
|
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
||||||
|
logger.info(
|
||||||
|
f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}"
|
||||||
|
)
|
||||||
|
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||||
|
# only paste masked area result
|
||||||
|
inpaint_result = cv2.resize(
|
||||||
|
inpaint_result,
|
||||||
|
(origin_size[1], origin_size[0]),
|
||||||
|
interpolation=cv2.INTER_CUBIC,
|
||||||
|
)
|
||||||
|
original_pixel_indices = mask < 127
|
||||||
|
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||||
|
return inpaint_result
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: Config):
|
||||||
"""
|
"""
|
||||||
@ -58,11 +85,11 @@ class PaintByExample(InpaintModel):
|
|||||||
"""
|
"""
|
||||||
if config.use_croper:
|
if config.use_croper:
|
||||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||||
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
||||||
inpaint_result = image[:, :, ::-1]
|
inpaint_result = image[:, :, ::-1]
|
||||||
inpaint_result[t:b, l:r, :] = crop_image
|
inpaint_result[t:b, l:r, :] = crop_image
|
||||||
else:
|
else:
|
||||||
inpaint_result = self._pad_forward(image, mask, config)
|
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
||||||
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
|
@ -8,7 +8,9 @@ from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerD
|
|||||||
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
|
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from lama_cleaner.helper import resize_max_size
|
||||||
from lama_cleaner.model.base import InpaintModel
|
from lama_cleaner.model.base import InpaintModel
|
||||||
|
from lama_cleaner.model.utils import torch_gc
|
||||||
from lama_cleaner.schema import Config, SDSampler
|
from lama_cleaner.schema import Config, SDSampler
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +20,8 @@ class CPUTextEncoderWrapper:
|
|||||||
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
|
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
|
||||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||||
self.torch_dtype = torch_dtype
|
self.torch_dtype = torch_dtype
|
||||||
|
del text_encoder
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
def __call__(self, x, **kwargs):
|
def __call__(self, x, **kwargs):
|
||||||
input_device = x.device
|
input_device = x.device
|
||||||
@ -30,9 +34,9 @@ class SD(InpaintModel):
|
|||||||
|
|
||||||
def init_model(self, device: torch.device, **kwargs):
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||||
fp16 = not kwargs['no_half']
|
fp16 = not kwargs.get('no_half', False)
|
||||||
|
|
||||||
model_kwargs = {"local_files_only": kwargs['sd_run_local']}
|
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
|
||||||
if kwargs['sd_disable_nsfw']:
|
if kwargs['sd_disable_nsfw']:
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(dict(
|
model_kwargs.update(dict(
|
||||||
@ -48,19 +52,43 @@ class SD(InpaintModel):
|
|||||||
use_auth_token=kwargs["hf_access_token"],
|
use_auth_token=kwargs["hf_access_token"],
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
)
|
)
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||||
self.model.enable_attention_slicing()
|
self.model.enable_attention_slicing()
|
||||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||||
if kwargs.get('sd_enable_xformers', False):
|
if kwargs.get('sd_enable_xformers', False):
|
||||||
self.model.enable_xformers_memory_efficient_attention()
|
self.model.enable_xformers_memory_efficient_attention()
|
||||||
self.model = self.model.to(device)
|
|
||||||
|
|
||||||
if kwargs['sd_cpu_textencoder']:
|
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
||||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
# TODO: gpu_id
|
||||||
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
else:
|
||||||
|
if kwargs['sd_cpu_textencoder']:
|
||||||
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
|
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
||||||
|
|
||||||
self.callback = kwargs.pop("callback", None)
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
|
def _scaled_pad_forward(self, image, mask, config: Config):
|
||||||
|
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||||
|
origin_size = image.shape[:2]
|
||||||
|
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||||
|
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
||||||
|
logger.info(
|
||||||
|
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
||||||
|
)
|
||||||
|
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||||
|
# only paste masked area result
|
||||||
|
inpaint_result = cv2.resize(
|
||||||
|
inpaint_result,
|
||||||
|
(origin_size[1], origin_size[0]),
|
||||||
|
interpolation=cv2.INTER_CUBIC,
|
||||||
|
)
|
||||||
|
original_pixel_indices = mask < 127
|
||||||
|
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||||
|
return inpaint_result
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
image: [H, W, C] RGB
|
image: [H, W, C] RGB
|
||||||
@ -126,11 +154,11 @@ class SD(InpaintModel):
|
|||||||
# boxes = boxes_from_mask(mask)
|
# boxes = boxes_from_mask(mask)
|
||||||
if config.use_croper:
|
if config.use_croper:
|
||||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||||
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
||||||
inpaint_result = image[:, :, ::-1]
|
inpaint_result = image[:, :, ::-1]
|
||||||
inpaint_result[t:b, l:r, :] = crop_image
|
inpaint_result[t:b, l:r, :] = crop_image
|
||||||
else:
|
else:
|
||||||
inpaint_result = self._pad_forward(image, mask, config)
|
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
||||||
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
|
@ -707,3 +707,9 @@ class Conv2dLayer(torch.nn.Module):
|
|||||||
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
||||||
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
|
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def torch_gc():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.ipc_collect()
|
||||||
|
@ -15,7 +15,9 @@ def parse_args():
|
|||||||
default="lama",
|
default="lama",
|
||||||
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
|
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
|
||||||
)
|
)
|
||||||
parser.add_argument("--no-half", action="store_true", help="SD/PaintByExample model no half precision")
|
parser.add_argument("--no-half", action="store_true", help="sd/paint_by_example model no half precision")
|
||||||
|
parser.add_argument("--cpu-offload", action="store_true",
|
||||||
|
help="sd/paint_by_example model, offloads all models to CPU, significantly reducing vRAM usage.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hf_access_token",
|
"--hf_access_token",
|
||||||
default="",
|
default="",
|
||||||
@ -34,7 +36,12 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sd-run-local",
|
"--sd-run-local",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token",
|
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--local-files-only",
|
||||||
|
action="store_true",
|
||||||
|
help="sd/paint_by_example model. Use local files only, not connect to huggingface server",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sd-enable-xformers",
|
"--sd-enable-xformers",
|
||||||
@ -80,7 +87,7 @@ def parse_args():
|
|||||||
if not output_dir.is_dir():
|
if not output_dir.is_dir():
|
||||||
parser.error(f"invalid --output-dir: {output_dir} is not a directory")
|
parser.error(f"invalid --output-dir: {output_dir} is not a directory")
|
||||||
|
|
||||||
if args.model == 'sd1.5' and not args.sd_run_local:
|
if args.model == 'sd1.5' and not (args.sd_run_local or args.local_files_only):
|
||||||
if not args.hf_access_token.startswith("hf_"):
|
if not args.hf_access_token.startswith("hf_"):
|
||||||
parser.error(
|
parser.error(
|
||||||
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"
|
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"
|
||||||
|
@ -58,6 +58,9 @@ class Config(BaseModel):
|
|||||||
croper_height: int = None
|
croper_height: int = None
|
||||||
croper_width: int = None
|
croper_width: int = None
|
||||||
|
|
||||||
|
# Resize the image before doing sd inpainting, the area outside the mask will not lose quality.
|
||||||
|
# Used by sd models and paint_by_example model
|
||||||
|
sd_scale: float = 1.0
|
||||||
# Blur the edge of mask area. The higher the number the smoother blend with the original image
|
# Blur the edge of mask area. The higher the number the smoother blend with the original image
|
||||||
sd_mask_blur: int = 0
|
sd_mask_blur: int = 0
|
||||||
# Ignore this value, it's useless for inpainting
|
# Ignore this value, it's useless for inpainting
|
||||||
|
@ -31,7 +31,6 @@ except:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
|
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
|
||||||
from flask_caching import Cache
|
|
||||||
|
|
||||||
# Disable ability for Flask to display warning about using a development server in a production environment.
|
# Disable ability for Flask to display warning about using a development server in a production environment.
|
||||||
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
||||||
@ -67,12 +66,9 @@ class NoFlaskwebgui(logging.Filter):
|
|||||||
|
|
||||||
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
||||||
|
|
||||||
cache = Cache(config={'CACHE_TYPE': 'SimpleCache'}, with_jinja2_ext=False)
|
|
||||||
|
|
||||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||||
app.config["JSON_AS_ASCII"] = False
|
app.config["JSON_AS_ASCII"] = False
|
||||||
CORS(app, expose_headers=["Content-Disposition"])
|
CORS(app, expose_headers=["Content-Disposition"])
|
||||||
cache.init_app(app)
|
|
||||||
|
|
||||||
model: ModelManager = None
|
model: ModelManager = None
|
||||||
thumb = FileManager(app)
|
thumb = FileManager(app)
|
||||||
@ -96,6 +92,16 @@ def diffuser_callback(i, t, latents):
|
|||||||
# socketio.emit('diffusion_step', {'diffusion_step': step})
|
# socketio.emit('diffusion_step', {'diffusion_step': step})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/save_image", methods=["POST"])
|
||||||
|
def save_image():
|
||||||
|
# all image in output directory
|
||||||
|
input = request.files
|
||||||
|
origin_image_bytes = input["image"].read() # RGB
|
||||||
|
image, _ = load_img(origin_image_bytes)
|
||||||
|
thumb.save_to_output_directory(image, request.form["filename"])
|
||||||
|
return 'ok', 200
|
||||||
|
|
||||||
|
|
||||||
@app.route("/medias")
|
@app.route("/medias")
|
||||||
def medias():
|
def medias():
|
||||||
# all images in input folder
|
# all images in input folder
|
||||||
@ -172,6 +178,7 @@ def process():
|
|||||||
croper_y=form["croperY"],
|
croper_y=form["croperY"],
|
||||||
croper_height=form["croperHeight"],
|
croper_height=form["croperHeight"],
|
||||||
croper_width=form["croperWidth"],
|
croper_width=form["croperWidth"],
|
||||||
|
sd_scale=form["sdScale"],
|
||||||
sd_mask_blur=form["sdMaskBlur"],
|
sd_mask_blur=form["sdMaskBlur"],
|
||||||
sd_strength=form["sdStrength"],
|
sd_strength=form["sdStrength"],
|
||||||
sd_steps=form["sdSteps"],
|
sd_steps=form["sdSteps"],
|
||||||
@ -345,6 +352,7 @@ def main(args):
|
|||||||
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
||||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'thumbnails')
|
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'thumbnails')
|
||||||
is_enable_file_manager = True
|
is_enable_file_manager = True
|
||||||
|
thumb.output_dir = Path(args.output_dir)
|
||||||
else:
|
else:
|
||||||
input_image_path = args.input
|
input_image_path = args.input
|
||||||
|
|
||||||
@ -356,6 +364,8 @@ def main(args):
|
|||||||
sd_disable_nsfw=args.sd_disable_nsfw,
|
sd_disable_nsfw=args.sd_disable_nsfw,
|
||||||
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||||
sd_run_local=args.sd_run_local,
|
sd_run_local=args.sd_run_local,
|
||||||
|
local_files_only=args.local_files_only,
|
||||||
|
cpu_offload=args.cpu_offload,
|
||||||
sd_enable_xformers=args.sd_enable_xformers,
|
sd_enable_xformers=args.sd_enable_xformers,
|
||||||
callback=diffuser_callback,
|
callback=diffuser_callback,
|
||||||
)
|
)
|
||||||
|
@ -38,7 +38,7 @@ def assert_equal(
|
|||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
def test_paint_by_example(strategy):
|
def test_paint_by_example(strategy):
|
||||||
model = ModelManager(name="paint_by_example", device=device)
|
model = ModelManager(name="paint_by_example", device=device)
|
||||||
cfg = get_config(strategy, paint_by_example_steps=30 if device == 'cuda' else 1)
|
cfg = get_config(strategy, paint_by_example_steps=30)
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
cfg,
|
cfg,
|
||||||
@ -46,5 +46,35 @@ def test_paint_by_example(strategy):
|
|||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fy=0.9,
|
fy=0.9,
|
||||||
|
fx=1.3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
def test_paint_by_example_sd_scale(strategy):
|
||||||
|
model = ModelManager(name="paint_by_example", device=device)
|
||||||
|
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"paint_by_example_{strategy.capitalize()}_sdscale.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fy=0.9,
|
||||||
|
fx=1.3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
def test_paint_by_example_cpu_offload(strategy):
|
||||||
|
model = ModelManager(name="paint_by_example", device=device, cpu_offload=True)
|
||||||
|
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fy=0.9,
|
||||||
fx=1.3
|
fx=1.3
|
||||||
)
|
)
|
||||||
|
@ -99,8 +99,8 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
|||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=True,
|
sd_disable_nsfw=False,
|
||||||
sd_cpu_textencoder=True,
|
sd_cpu_textencoder=False,
|
||||||
callback=callback)
|
callback=callback)
|
||||||
cfg = get_config(
|
cfg = get_config(
|
||||||
strategy,
|
strategy,
|
||||||
@ -121,3 +121,63 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
|||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fx=1
|
fx=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sd_device", ['cuda'])
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
||||||
|
@pytest.mark.parametrize("cpu_textencoder", [False])
|
||||||
|
@pytest.mark.parametrize("disable_nsfw", [False])
|
||||||
|
def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
||||||
|
if sd_device == 'cuda' and not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
sd_steps = 50 if sd_device == 'cuda' else 1
|
||||||
|
model = ModelManager(name="sd1.5",
|
||||||
|
device=torch.device(sd_device),
|
||||||
|
hf_access_token="",
|
||||||
|
sd_run_local=True,
|
||||||
|
sd_disable_nsfw=disable_nsfw,
|
||||||
|
sd_cpu_textencoder=cpu_textencoder)
|
||||||
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||||
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
|
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fx=1.3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sd_device", ['cuda'])
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
||||||
|
def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
||||||
|
if sd_device == 'cuda' and not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
sd_steps = 50 if sd_device == 'cuda' else 1
|
||||||
|
model = ModelManager(name="sd1.5",
|
||||||
|
device=torch.device(sd_device),
|
||||||
|
hf_access_token="",
|
||||||
|
sd_run_local=True,
|
||||||
|
sd_disable_nsfw=False,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
|
cpu_offload=True)
|
||||||
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||||
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
|
name = f"device_{sd_device}_{sampler}"
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
||||||
|
@ -12,3 +12,4 @@ markupsafe==2.0.1
|
|||||||
scikit-image==0.19.3
|
scikit-image==0.19.3
|
||||||
diffusers[torch]==0.10.2
|
diffusers[torch]==0.10.2
|
||||||
transformers>=4.25.1
|
transformers>=4.25.1
|
||||||
|
cachetools==5.2.0
|
Loading…
Reference in New Issue
Block a user