lots of updates
This commit is contained in:
parent
2e8e52f7a5
commit
a22536becc
@ -1,8 +1,7 @@
|
||||
import React, { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import React, { useCallback, useEffect, useMemo } from 'react'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import { nanoid } from 'nanoid'
|
||||
import useInputImage from './hooks/useInputImage'
|
||||
import LandingPage from './components/LandingPage/LandingPage'
|
||||
import { themeState } from './components/Header/ThemeChanger'
|
||||
import Workspace from './components/Workspace'
|
||||
import {
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { Rect, Settings } from '../store/Atoms'
|
||||
import { dataURItoBlob } from '../utils'
|
||||
import { dataURItoBlob, srcToFile } from '../utils'
|
||||
|
||||
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
|
||||
|
||||
@ -57,6 +57,7 @@ export default async function inpaint(
|
||||
fd.append('sdSampler', settings.sdSampler.toString())
|
||||
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
||||
fd.append('sdMatchHistograms', settings.sdMatchHistograms ? 'true' : 'false')
|
||||
fd.append('sdScale', (settings.sdScale / 100).toString())
|
||||
|
||||
fd.append('cv2Radius', settings.cv2Radius.toString())
|
||||
fd.append('cv2Flag', settings.cv2Flag.toString())
|
||||
@ -198,3 +199,27 @@ export async function getMedias() {
|
||||
const errMsg = await res.text()
|
||||
throw new Error(errMsg)
|
||||
}
|
||||
|
||||
export async function downloadToOutput(
|
||||
image: HTMLImageElement,
|
||||
filename: string,
|
||||
mimeType: string
|
||||
) {
|
||||
const file = await srcToFile(image.src, filename, mimeType)
|
||||
const fd = new FormData()
|
||||
fd.append('image', file)
|
||||
fd.append('filename', filename)
|
||||
|
||||
try {
|
||||
const res = await fetch(`${API_ENDPOINT}/save_image`, {
|
||||
method: 'POST',
|
||||
body: fd,
|
||||
})
|
||||
if (!res.ok) {
|
||||
const errMsg = await res.text()
|
||||
throw new Error(errMsg)
|
||||
}
|
||||
} catch (error) {
|
||||
throw new Error(`Something went wrong: ${error}`)
|
||||
}
|
||||
}
|
||||
|
@ -18,7 +18,10 @@ import {
|
||||
} from 'react-zoom-pan-pinch'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
|
||||
import inpaint, { postInteractiveSeg } from '../../adapters/inpainting'
|
||||
import inpaint, {
|
||||
downloadToOutput,
|
||||
postInteractiveSeg,
|
||||
} from '../../adapters/inpainting'
|
||||
import Button from '../shared/Button'
|
||||
import Slider from './Slider'
|
||||
import SizeSelector from './SizeSelector'
|
||||
@ -34,7 +37,10 @@ import {
|
||||
} from '../../utils'
|
||||
import {
|
||||
croperState,
|
||||
enableFileManagerState,
|
||||
fileState,
|
||||
imageHeightState,
|
||||
imageWidthState,
|
||||
interactiveSegClicksState,
|
||||
isInpaintingState,
|
||||
isInteractiveSegRunningState,
|
||||
@ -173,6 +179,10 @@ export default function Editor() {
|
||||
const [redoRenders, setRedoRenders] = useState<HTMLImageElement[]>([])
|
||||
const [redoCurLines, setRedoCurLines] = useState<Line[]>([])
|
||||
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
|
||||
const enableFileManager = useRecoilValue(enableFileManagerState)
|
||||
|
||||
const [imageWidth, setImageWidth] = useRecoilState(imageWidthState)
|
||||
const [imageHeight, setImageHeight] = useRecoilState(imageHeightState)
|
||||
|
||||
const draw = useCallback(
|
||||
(render: HTMLImageElement, lineGroup: LineGroup) => {
|
||||
@ -524,6 +534,9 @@ export default function Editor() {
|
||||
const rW = windowSize.width / original.naturalWidth
|
||||
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
|
||||
|
||||
setImageWidth(original.naturalWidth)
|
||||
setImageHeight(original.naturalHeight)
|
||||
|
||||
let s = 1.0
|
||||
if (rW < 1 || rH < 1) {
|
||||
s = Math.min(rW, rH)
|
||||
@ -1054,6 +1067,27 @@ export default function Editor() {
|
||||
if (file === undefined) {
|
||||
return
|
||||
}
|
||||
if (enableFileManager && renders.length > 0) {
|
||||
try {
|
||||
downloadToOutput(renders[renders.length - 1], file.name, file.type)
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: `Save image success`,
|
||||
state: 'success',
|
||||
duration: 2000,
|
||||
})
|
||||
} catch (e: any) {
|
||||
setToastState({
|
||||
open: true,
|
||||
desc: e.message ? e.message : e.toString(),
|
||||
state: 'error',
|
||||
duration: 2000,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: download to output directory
|
||||
const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1')
|
||||
const curRender = renders[renders.length - 1]
|
||||
downloadImage(curRender.currentSrc, name)
|
||||
|
@ -6,14 +6,23 @@ type SliderProps = {
|
||||
min?: number
|
||||
max?: number
|
||||
onChange: (value: number) => void
|
||||
onClick: () => void
|
||||
onClick?: () => void
|
||||
width?: number
|
||||
}
|
||||
|
||||
export default function Slider(props: SliderProps) {
|
||||
const { value, onChange, onClick, label, min, max } = props
|
||||
const { value, onChange, onClick, label, min, max, width } = props
|
||||
const styles: any = {}
|
||||
if (width !== undefined) {
|
||||
styles.width = width
|
||||
}
|
||||
|
||||
const step = ((max || 100) - (min || 0)) / 100
|
||||
|
||||
const onMouseUp = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
e.currentTarget?.blur()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="editor-brush-slider">
|
||||
<span>{label}</span>
|
||||
@ -29,6 +38,8 @@ export default function Slider(props: SliderProps) {
|
||||
onChange(parseInt(ev.currentTarget.value, 10))
|
||||
}}
|
||||
onClick={onClick}
|
||||
style={styles}
|
||||
onMouseUp={onMouseUp}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ import React, {
|
||||
useState,
|
||||
useCallback,
|
||||
} from 'react'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import { useRecoilState } from 'recoil'
|
||||
import PhotoAlbum, { RenderPhoto } from 'react-photo-album'
|
||||
import * as ScrollArea from '@radix-ui/react-scroll-area'
|
||||
import Modal from '../shared/Modal'
|
||||
@ -127,7 +127,7 @@ export default function FileManager(props: Props) {
|
||||
ref={onRefChange}
|
||||
>
|
||||
<PhotoAlbum
|
||||
layout="columns"
|
||||
layout="masonry"
|
||||
photos={photos}
|
||||
renderPhoto={renderPhoto}
|
||||
spacing={8}
|
||||
|
@ -21,7 +21,7 @@ function SettingBlock(props: SettingBlockProps) {
|
||||
<div className={`setting-block ${className}`}>
|
||||
<div className={contentClass}>
|
||||
<div className="setting-block-content-title">
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: '12px' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: '8px' }}>
|
||||
{desc ? (
|
||||
<Tooltip content={<div style={{ maxWidth: 400 }}>{desc}</div>}>
|
||||
<span>{title}</span>
|
||||
|
@ -1,8 +1,5 @@
|
||||
import React, { FormEvent } from 'react'
|
||||
|
||||
import React from 'react'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import { FolderOpenIcon } from '@heroicons/react/24/outline'
|
||||
import * as Tabs from '@radix-ui/react-tabs'
|
||||
import {
|
||||
isPaintByExampleState,
|
||||
isSDState,
|
||||
@ -14,10 +11,6 @@ import HDSettingBlock from './HDSettingBlock'
|
||||
import ModelSettingBlock from './ModelSettingBlock'
|
||||
import DownloadMaskSettingBlock from './DownloadMaskSettingBlock'
|
||||
import useHotKey from '../../hooks/useHotkey'
|
||||
import SettingBlock from './SettingBlock'
|
||||
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||
import Button from '../shared/Button'
|
||||
import TextInput from '../shared/Input'
|
||||
|
||||
declare module 'react' {
|
||||
interface InputHTMLAttributes<T> extends HTMLAttributes<T> {
|
||||
|
@ -0,0 +1,58 @@
|
||||
import React from 'react'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import { appState, croperState, settingState } from '../../store/Atoms'
|
||||
import Slider from '../Editor/Slider'
|
||||
import SettingBlock from '../Settings/SettingBlock'
|
||||
|
||||
const ImageResizeScale = () => {
|
||||
const [setting, setSettingState] = useRecoilState(settingState)
|
||||
const app = useRecoilValue(appState)
|
||||
const croper = useRecoilValue(croperState)
|
||||
|
||||
const handleSliderChange = (value: number) => {
|
||||
setSettingState(old => {
|
||||
return { ...old, sdScale: value }
|
||||
})
|
||||
}
|
||||
|
||||
const scaledWidth = () => {
|
||||
let width = app.imageWidth
|
||||
if (setting.showCroper) {
|
||||
width = croper.width
|
||||
}
|
||||
return Math.round((width * setting.sdScale) / 100)
|
||||
}
|
||||
|
||||
const scaledHeight = () => {
|
||||
let height = app.imageHeight
|
||||
if (setting.showCroper) {
|
||||
height = croper.height
|
||||
}
|
||||
return Math.round((height * setting.sdScale) / 100)
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingBlock
|
||||
className="sub-setting-block"
|
||||
title="Resize"
|
||||
titleSuffix={
|
||||
<div
|
||||
style={{ width: 86 }}
|
||||
>{`(${scaledWidth()}x${scaledHeight()})`}</div>
|
||||
}
|
||||
desc="Resize the image before inpainting, the area outside the mask will not lose quality."
|
||||
input={
|
||||
<Slider
|
||||
label=""
|
||||
width={70}
|
||||
min={50}
|
||||
max={100}
|
||||
value={setting.sdScale}
|
||||
onChange={handleSliderChange}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default ImageResizeScale
|
@ -14,6 +14,7 @@ import { Switch, SwitchThumb } from '../shared/Switch'
|
||||
import Button from '../shared/Button'
|
||||
import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event'
|
||||
import { useImage } from '../../utils'
|
||||
import ImageResizeScale from './ImageResizeScale'
|
||||
|
||||
const INPUT_WIDTH = 30
|
||||
|
||||
@ -68,22 +69,6 @@ const PESidePanel = () => {
|
||||
</PopoverPrimitive.Trigger>
|
||||
<PopoverPrimitive.Portal>
|
||||
<PopoverPrimitive.Content className="side-panel-content">
|
||||
<SettingBlock
|
||||
title="Croper"
|
||||
input={
|
||||
<Switch
|
||||
checked={setting.showCroper}
|
||||
onCheckedChange={value => {
|
||||
setSettingState(old => {
|
||||
return { ...old, showCroper: value }
|
||||
})
|
||||
}}
|
||||
>
|
||||
<SwitchThumb />
|
||||
</Switch>
|
||||
}
|
||||
/>
|
||||
|
||||
<NumberInputSetting
|
||||
title="Steps"
|
||||
width={INPUT_WIDTH}
|
||||
@ -97,6 +82,8 @@ const PESidePanel = () => {
|
||||
}}
|
||||
/>
|
||||
|
||||
<ImageResizeScale />
|
||||
|
||||
<NumberInputSetting
|
||||
title="Guidance Scale"
|
||||
width={INPUT_WIDTH}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import React, { FormEvent, useState } from 'react'
|
||||
import React, { FormEvent } from 'react'
|
||||
import { useRecoilState, useRecoilValue } from 'recoil'
|
||||
import * as PopoverPrimitive from '@radix-ui/react-popover'
|
||||
import { useToggle } from 'react-use'
|
||||
@ -15,6 +15,7 @@ import Selector from '../shared/Selector'
|
||||
import { Switch, SwitchThumb } from '../shared/Switch'
|
||||
import TextAreaInput from '../shared/Textarea'
|
||||
import emitter, { EVENT_PROMPT } from '../../event'
|
||||
import ImageResizeScale from './ImageResizeScale'
|
||||
|
||||
const INPUT_WIDTH = 30
|
||||
|
||||
@ -71,6 +72,9 @@ const SidePanel = () => {
|
||||
</Switch>
|
||||
}
|
||||
/>
|
||||
|
||||
<ImageResizeScale />
|
||||
|
||||
{/*
|
||||
<NumberInputSetting
|
||||
title="Num Samples"
|
||||
@ -98,21 +102,6 @@ const SidePanel = () => {
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* <NumberInputSetting
|
||||
title="Strength"
|
||||
width={INPUT_WIDTH}
|
||||
allowFloat
|
||||
value={`${setting.sdStrength}`}
|
||||
desc="TODO"
|
||||
onValue={value => {
|
||||
const val = value.length === 0 ? 0 : parseFloat(value)
|
||||
console.log(val)
|
||||
setSettingState(old => {
|
||||
return { ...old, sdStrength: val }
|
||||
})
|
||||
}}
|
||||
/> */}
|
||||
|
||||
<NumberInputSetting
|
||||
title="Guidance Scale"
|
||||
width={INPUT_WIDTH}
|
||||
|
@ -35,6 +35,8 @@ export interface Rect {
|
||||
|
||||
interface AppState {
|
||||
file: File | undefined
|
||||
imageHeight: number
|
||||
imageWidth: number
|
||||
disableShortCuts: boolean
|
||||
isInpainting: boolean
|
||||
isDisableModelSwitch: boolean
|
||||
@ -49,6 +51,8 @@ export const appState = atom<AppState>({
|
||||
key: 'appState',
|
||||
default: {
|
||||
file: undefined,
|
||||
imageHeight: 0,
|
||||
imageWidth: 0,
|
||||
disableShortCuts: false,
|
||||
isInpainting: false,
|
||||
isDisableModelSwitch: false,
|
||||
@ -82,6 +86,30 @@ export const isInpaintingState = selector({
|
||||
},
|
||||
})
|
||||
|
||||
export const imageHeightState = selector({
|
||||
key: 'imageHeightState',
|
||||
get: ({ get }) => {
|
||||
const app = get(appState)
|
||||
return app.imageHeight
|
||||
},
|
||||
set: ({ get, set }, newValue: any) => {
|
||||
const app = get(appState)
|
||||
set(appState, { ...app, imageHeight: newValue })
|
||||
},
|
||||
})
|
||||
|
||||
export const imageWidthState = selector({
|
||||
key: 'imageWidthState',
|
||||
get: ({ get }) => {
|
||||
const app = get(appState)
|
||||
return app.imageWidth
|
||||
},
|
||||
set: ({ get, set }, newValue: any) => {
|
||||
const app = get(appState)
|
||||
set(appState, { ...app, imageWidth: newValue })
|
||||
},
|
||||
})
|
||||
|
||||
export const showFileManagerState = selector({
|
||||
key: 'showFileManager',
|
||||
get: ({ get }) => {
|
||||
@ -121,6 +149,12 @@ export const fileState = selector({
|
||||
isInteractiveSeg: false,
|
||||
isInteractiveSegRunning: false,
|
||||
})
|
||||
|
||||
const setting = get(settingState)
|
||||
set(settingState, {
|
||||
...setting,
|
||||
sdScale: 100,
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
@ -282,6 +316,7 @@ export interface Settings {
|
||||
sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend
|
||||
sdNumSamples: number
|
||||
sdMatchHistograms: boolean
|
||||
sdScale: number
|
||||
|
||||
// For OpenCV2
|
||||
cv2Radius: number
|
||||
@ -409,6 +444,7 @@ export const settingStateDefault: Settings = {
|
||||
sdSeedFixed: true,
|
||||
sdNumSamples: 1,
|
||||
sdMatchHistograms: false,
|
||||
sdScale: 100,
|
||||
|
||||
// CV2
|
||||
cv2Radius: 5,
|
||||
|
@ -1,11 +1,16 @@
|
||||
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from cachetools import TTLCache, cached
|
||||
import cv2
|
||||
import time
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2)
|
||||
from .storage_backends import FilesystemStorageBackend
|
||||
from .utils import aspect_to_string, generate_filename, glob_img
|
||||
|
||||
@ -18,6 +23,7 @@ class FileManager:
|
||||
self._default_root_url = "/"
|
||||
self._default_thumbnail_root_url = "/"
|
||||
self._default_format = "JPEG"
|
||||
self.output_dir: Path = None
|
||||
|
||||
if app is not None:
|
||||
self.init_app(app)
|
||||
@ -41,6 +47,16 @@ class FileManager:
|
||||
app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url)
|
||||
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
|
||||
|
||||
def save_to_output_directory(self, image: np.ndarray, filename: str):
|
||||
fp = Path(filename)
|
||||
new_name = fp.stem + f"_{int(time.time())}" + fp.suffix
|
||||
if image.shape[2] == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
elif image.shape[2] == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)
|
||||
|
||||
cv2.imwrite(str(self.output_dir / new_name), image)
|
||||
|
||||
@property
|
||||
def root_directory(self):
|
||||
path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
|
||||
@ -64,7 +80,7 @@ class FileManager:
|
||||
return self.app.config["THUMBNAIL_MEDIA_URL"]
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
@cached(cache=TTLCache(maxsize=1024, ttl=30))
|
||||
def media_names(self):
|
||||
names = sorted([it.name for it in glob_img(self.root_directory)])
|
||||
res = []
|
||||
|
@ -6,6 +6,9 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import resize_max_size
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
@ -15,15 +18,20 @@ class PaintByExample(InpaintModel):
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
fp16 = not kwargs['no_half']
|
||||
fp16 = not kwargs.get('no_half', False)
|
||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
||||
self.model = DiffusionPipeline.from_pretrained(
|
||||
"Fantasy-Studio/Paint-by-Example",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs
|
||||
)
|
||||
self.model.enable_attention_slicing()
|
||||
self.model = self.model.to(device)
|
||||
self.model.enable_attention_slicing()
|
||||
# TODO: gpu_id
|
||||
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
@ -49,6 +57,25 @@ class PaintByExample(InpaintModel):
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
def _scaled_pad_forward(self, image, mask, config: Config):
|
||||
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||
origin_size = image.shape[:2]
|
||||
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
||||
logger.info(
|
||||
f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}"
|
||||
)
|
||||
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||
# only paste masked area result
|
||||
inpaint_result = cv2.resize(
|
||||
inpaint_result,
|
||||
(origin_size[1], origin_size[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
original_pixel_indices = mask < 127
|
||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||
return inpaint_result
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: Config):
|
||||
"""
|
||||
@ -58,11 +85,11 @@ class PaintByExample(InpaintModel):
|
||||
"""
|
||||
if config.use_croper:
|
||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
||||
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
inpaint_result[t:b, l:r, :] = crop_image
|
||||
else:
|
||||
inpaint_result = self._pad_forward(image, mask, config)
|
||||
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
||||
|
||||
return inpaint_result
|
||||
|
||||
|
@ -8,7 +8,9 @@ from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerD
|
||||
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import resize_max_size
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.schema import Config, SDSampler
|
||||
|
||||
|
||||
@ -18,6 +20,8 @@ class CPUTextEncoderWrapper:
|
||||
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
|
||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||
self.torch_dtype = torch_dtype
|
||||
del text_encoder
|
||||
torch_gc()
|
||||
|
||||
def __call__(self, x, **kwargs):
|
||||
input_device = x.device
|
||||
@ -30,9 +34,9 @@ class SD(InpaintModel):
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||
fp16 = not kwargs['no_half']
|
||||
fp16 = not kwargs.get('no_half', False)
|
||||
|
||||
model_kwargs = {"local_files_only": kwargs['sd_run_local']}
|
||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
|
||||
if kwargs['sd_disable_nsfw']:
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(dict(
|
||||
@ -48,19 +52,43 @@ class SD(InpaintModel):
|
||||
use_auth_token=kwargs["hf_access_token"],
|
||||
**model_kwargs
|
||||
)
|
||||
self.model = self.model.to(device)
|
||||
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||
if kwargs.get('sd_enable_xformers', False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
self.model = self.model.to(device)
|
||||
|
||||
if kwargs['sd_cpu_textencoder']:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
||||
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
||||
# TODO: gpu_id
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
if kwargs['sd_cpu_textencoder']:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def _scaled_pad_forward(self, image, mask, config: Config):
|
||||
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
||||
origin_size = image.shape[:2]
|
||||
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
||||
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
||||
logger.info(
|
||||
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
||||
)
|
||||
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
||||
# only paste masked area result
|
||||
inpaint_result = cv2.resize(
|
||||
inpaint_result,
|
||||
(origin_size[1], origin_size[0]),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
original_pixel_indices = mask < 127
|
||||
inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
|
||||
return inpaint_result
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
@ -126,11 +154,11 @@ class SD(InpaintModel):
|
||||
# boxes = boxes_from_mask(mask)
|
||||
if config.use_croper:
|
||||
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
||||
crop_image = self._pad_forward(crop_img, crop_mask, config)
|
||||
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
||||
inpaint_result = image[:, :, ::-1]
|
||||
inpaint_result[t:b, l:r, :] = crop_image
|
||||
else:
|
||||
inpaint_result = self._pad_forward(image, mask, config)
|
||||
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
||||
|
||||
return inpaint_result
|
||||
|
||||
|
@ -707,3 +707,9 @@ class Conv2dLayer(torch.nn.Module):
|
||||
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
||||
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
|
||||
return out
|
||||
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
@ -15,7 +15,9 @@ def parse_args():
|
||||
default="lama",
|
||||
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
|
||||
)
|
||||
parser.add_argument("--no-half", action="store_true", help="SD/PaintByExample model no half precision")
|
||||
parser.add_argument("--no-half", action="store_true", help="sd/paint_by_example model no half precision")
|
||||
parser.add_argument("--cpu-offload", action="store_true",
|
||||
help="sd/paint_by_example model, offloads all models to CPU, significantly reducing vRAM usage.")
|
||||
parser.add_argument(
|
||||
"--hf_access_token",
|
||||
default="",
|
||||
@ -34,7 +36,12 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--sd-run-local",
|
||||
action="store_true",
|
||||
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token",
|
||||
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-files-only",
|
||||
action="store_true",
|
||||
help="sd/paint_by_example model. Use local files only, not connect to huggingface server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd-enable-xformers",
|
||||
@ -80,7 +87,7 @@ def parse_args():
|
||||
if not output_dir.is_dir():
|
||||
parser.error(f"invalid --output-dir: {output_dir} is not a directory")
|
||||
|
||||
if args.model == 'sd1.5' and not args.sd_run_local:
|
||||
if args.model == 'sd1.5' and not (args.sd_run_local or args.local_files_only):
|
||||
if not args.hf_access_token.startswith("hf_"):
|
||||
parser.error(
|
||||
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"
|
||||
|
@ -58,6 +58,9 @@ class Config(BaseModel):
|
||||
croper_height: int = None
|
||||
croper_width: int = None
|
||||
|
||||
# Resize the image before doing sd inpainting, the area outside the mask will not lose quality.
|
||||
# Used by sd models and paint_by_example model
|
||||
sd_scale: float = 1.0
|
||||
# Blur the edge of mask area. The higher the number the smoother blend with the original image
|
||||
sd_mask_blur: int = 0
|
||||
# Ignore this value, it's useless for inpainting
|
||||
|
@ -31,7 +31,6 @@ except:
|
||||
pass
|
||||
|
||||
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
|
||||
from flask_caching import Cache
|
||||
|
||||
# Disable ability for Flask to display warning about using a development server in a production environment.
|
||||
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
||||
@ -67,12 +66,9 @@ class NoFlaskwebgui(logging.Filter):
|
||||
|
||||
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
||||
|
||||
cache = Cache(config={'CACHE_TYPE': 'SimpleCache'}, with_jinja2_ext=False)
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
CORS(app, expose_headers=["Content-Disposition"])
|
||||
cache.init_app(app)
|
||||
|
||||
model: ModelManager = None
|
||||
thumb = FileManager(app)
|
||||
@ -96,6 +92,16 @@ def diffuser_callback(i, t, latents):
|
||||
# socketio.emit('diffusion_step', {'diffusion_step': step})
|
||||
|
||||
|
||||
@app.route("/save_image", methods=["POST"])
|
||||
def save_image():
|
||||
# all image in output directory
|
||||
input = request.files
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
image, _ = load_img(origin_image_bytes)
|
||||
thumb.save_to_output_directory(image, request.form["filename"])
|
||||
return 'ok', 200
|
||||
|
||||
|
||||
@app.route("/medias")
|
||||
def medias():
|
||||
# all images in input folder
|
||||
@ -172,6 +178,7 @@ def process():
|
||||
croper_y=form["croperY"],
|
||||
croper_height=form["croperHeight"],
|
||||
croper_width=form["croperWidth"],
|
||||
sd_scale=form["sdScale"],
|
||||
sd_mask_blur=form["sdMaskBlur"],
|
||||
sd_strength=form["sdStrength"],
|
||||
sd_steps=form["sdSteps"],
|
||||
@ -345,6 +352,7 @@ def main(args):
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'thumbnails')
|
||||
is_enable_file_manager = True
|
||||
thumb.output_dir = Path(args.output_dir)
|
||||
else:
|
||||
input_image_path = args.input
|
||||
|
||||
@ -356,6 +364,8 @@ def main(args):
|
||||
sd_disable_nsfw=args.sd_disable_nsfw,
|
||||
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||
sd_run_local=args.sd_run_local,
|
||||
local_files_only=args.local_files_only,
|
||||
cpu_offload=args.cpu_offload,
|
||||
sd_enable_xformers=args.sd_enable_xformers,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
@ -38,7 +38,7 @@ def assert_equal(
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30 if device == 'cuda' else 1)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
@ -46,5 +46,35 @@ def test_paint_by_example(strategy):
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fy=0.9,
|
||||
fx=1.3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example_sd_scale(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"paint_by_example_{strategy.capitalize()}_sdscale.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fy=0.9,
|
||||
fx=1.3
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example_cpu_offload(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device, cpu_offload=True)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fy=0.9,
|
||||
fx=1.3
|
||||
)
|
||||
|
@ -99,8 +99,8 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
sd_disable_nsfw=False,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback)
|
||||
cfg = get_config(
|
||||
strategy,
|
||||
@ -121,3 +121,63 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
||||
@pytest.mark.parametrize("cpu_textencoder", [False])
|
||||
@pytest.mark.parametrize("disable_nsfw", [False])
|
||||
def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
||||
model = ModelManager(name="sd1.5",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.3
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
|
||||
def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
||||
if sd_device == 'cuda' and not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
sd_steps = 50 if sd_device == 'cuda' else 1
|
||||
model = ModelManager(name="sd1.5",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=False,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"device_{sd_device}_{sampler}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
@ -12,3 +12,4 @@ markupsafe==2.0.1
|
||||
scikit-image==0.19.3
|
||||
diffusers[torch]==0.10.2
|
||||
transformers>=4.25.1
|
||||
cachetools==5.2.0
|
Loading…
Reference in New Issue
Block a user