lots of updates

This commit is contained in:
Qing 2023-01-05 22:07:39 +08:00
parent 2e8e52f7a5
commit a22536becc
21 changed files with 394 additions and 74 deletions

View File

@ -1,8 +1,7 @@
import React, { useCallback, useEffect, useMemo, useState } from 'react' import React, { useCallback, useEffect, useMemo } from 'react'
import { useRecoilState } from 'recoil' import { useRecoilState } from 'recoil'
import { nanoid } from 'nanoid' import { nanoid } from 'nanoid'
import useInputImage from './hooks/useInputImage' import useInputImage from './hooks/useInputImage'
import LandingPage from './components/LandingPage/LandingPage'
import { themeState } from './components/Header/ThemeChanger' import { themeState } from './components/Header/ThemeChanger'
import Workspace from './components/Workspace' import Workspace from './components/Workspace'
import { import {

View File

@ -1,5 +1,5 @@
import { Rect, Settings } from '../store/Atoms' import { Rect, Settings } from '../store/Atoms'
import { dataURItoBlob } from '../utils' import { dataURItoBlob, srcToFile } from '../utils'
export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}` export const API_ENDPOINT = `${process.env.REACT_APP_INPAINTING_URL}`
@ -57,6 +57,7 @@ export default async function inpaint(
fd.append('sdSampler', settings.sdSampler.toString()) fd.append('sdSampler', settings.sdSampler.toString())
fd.append('sdSeed', seed ? seed.toString() : '-1') fd.append('sdSeed', seed ? seed.toString() : '-1')
fd.append('sdMatchHistograms', settings.sdMatchHistograms ? 'true' : 'false') fd.append('sdMatchHistograms', settings.sdMatchHistograms ? 'true' : 'false')
fd.append('sdScale', (settings.sdScale / 100).toString())
fd.append('cv2Radius', settings.cv2Radius.toString()) fd.append('cv2Radius', settings.cv2Radius.toString())
fd.append('cv2Flag', settings.cv2Flag.toString()) fd.append('cv2Flag', settings.cv2Flag.toString())
@ -198,3 +199,27 @@ export async function getMedias() {
const errMsg = await res.text() const errMsg = await res.text()
throw new Error(errMsg) throw new Error(errMsg)
} }
/**
 * Upload a rendered image to the backend `save_image` endpoint.
 *
 * The image's `src` is turned into a `File` with `srcToFile` and posted as
 * multipart form data together with the file name.
 *
 * @param image - Image element whose `src` is uploaded.
 * @param filename - Name used for the file part and sent as the `filename` field.
 * @param mimeType - MIME type handed to `srcToFile`.
 * @throws Error prefixed with "Something went wrong:" when the request fails
 *   or the response status is not OK.
 */
export async function downloadToOutput(
  image: HTMLImageElement,
  filename: string,
  mimeType: string
) {
  // Conversion happens outside the guarded section, so its errors propagate as-is.
  const upload = await srcToFile(image.src, filename, mimeType)
  const payload = new FormData()
  payload.append('image', upload)
  payload.append('filename', filename)

  let failure: unknown
  try {
    const response = await fetch(`${API_ENDPOINT}/save_image`, {
      method: 'POST',
      body: payload,
    })
    if (response.ok) {
      return
    }
    // The response body carries the server-side error text.
    failure = new Error(await response.text())
  } catch (cause) {
    failure = cause
  }
  throw new Error(`Something went wrong: ${failure}`)
}

View File

@ -18,7 +18,10 @@ import {
} from 'react-zoom-pan-pinch' } from 'react-zoom-pan-pinch'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState, useRecoilValue } from 'recoil'
import { useWindowSize, useKey, useKeyPressEvent } from 'react-use' import { useWindowSize, useKey, useKeyPressEvent } from 'react-use'
import inpaint, { postInteractiveSeg } from '../../adapters/inpainting' import inpaint, {
downloadToOutput,
postInteractiveSeg,
} from '../../adapters/inpainting'
import Button from '../shared/Button' import Button from '../shared/Button'
import Slider from './Slider' import Slider from './Slider'
import SizeSelector from './SizeSelector' import SizeSelector from './SizeSelector'
@ -34,7 +37,10 @@ import {
} from '../../utils' } from '../../utils'
import { import {
croperState, croperState,
enableFileManagerState,
fileState, fileState,
imageHeightState,
imageWidthState,
interactiveSegClicksState, interactiveSegClicksState,
isInpaintingState, isInpaintingState,
isInteractiveSegRunningState, isInteractiveSegRunningState,
@ -173,6 +179,10 @@ export default function Editor() {
const [redoRenders, setRedoRenders] = useState<HTMLImageElement[]>([]) const [redoRenders, setRedoRenders] = useState<HTMLImageElement[]>([])
const [redoCurLines, setRedoCurLines] = useState<Line[]>([]) const [redoCurLines, setRedoCurLines] = useState<Line[]>([])
const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([]) const [redoLineGroups, setRedoLineGroups] = useState<LineGroup[]>([])
const enableFileManager = useRecoilValue(enableFileManagerState)
const [imageWidth, setImageWidth] = useRecoilState(imageWidthState)
const [imageHeight, setImageHeight] = useRecoilState(imageHeightState)
const draw = useCallback( const draw = useCallback(
(render: HTMLImageElement, lineGroup: LineGroup) => { (render: HTMLImageElement, lineGroup: LineGroup) => {
@ -524,6 +534,9 @@ export default function Editor() {
const rW = windowSize.width / original.naturalWidth const rW = windowSize.width / original.naturalWidth
const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight const rH = (windowSize.height - TOOLBAR_SIZE) / original.naturalHeight
setImageWidth(original.naturalWidth)
setImageHeight(original.naturalHeight)
let s = 1.0 let s = 1.0
if (rW < 1 || rH < 1) { if (rW < 1 || rH < 1) {
s = Math.min(rW, rH) s = Math.min(rW, rH)
@ -1054,6 +1067,27 @@ export default function Editor() {
if (file === undefined) { if (file === undefined) {
return return
} }
if (enableFileManager && renders.length > 0) {
  // downloadToOutput returns a promise. A synchronous try/catch around the
  // un-awaited call can never observe its rejection, and a toast raised right
  // after the call would announce success before the upload has finished.
  // React to the settled promise instead.
  downloadToOutput(renders[renders.length - 1], file.name, file.type)
    .then(() => {
      setToastState({
        open: true,
        desc: `Save image success`,
        state: 'success',
        duration: 2000,
      })
    })
    .catch((e: unknown) => {
      setToastState({
        open: true,
        // Prefer the error's message; fall back to its string form otherwise.
        desc: e instanceof Error && e.message ? e.message : String(e),
        state: 'error',
        duration: 2000,
      })
    })
  return
}
// TODO: download to output directory
const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1') const name = file.name.replace(/(\.[\w\d_-]+)$/i, '_cleanup$1')
const curRender = renders[renders.length - 1] const curRender = renders[renders.length - 1]
downloadImage(curRender.currentSrc, name) downloadImage(curRender.currentSrc, name)

View File

@ -6,14 +6,23 @@ type SliderProps = {
min?: number min?: number
max?: number max?: number
onChange: (value: number) => void onChange: (value: number) => void
onClick: () => void onClick?: () => void
width?: number
} }
export default function Slider(props: SliderProps) { export default function Slider(props: SliderProps) {
const { value, onChange, onClick, label, min, max } = props const { value, onChange, onClick, label, min, max, width } = props
const styles: any = {}
if (width !== undefined) {
styles.width = width
}
const step = ((max || 100) - (min || 0)) / 100 const step = ((max || 100) - (min || 0)) / 100
const onMouseUp = (e: React.MouseEvent<HTMLDivElement>) => {
e.currentTarget?.blur()
}
return ( return (
<div className="editor-brush-slider"> <div className="editor-brush-slider">
<span>{label}</span> <span>{label}</span>
@ -29,6 +38,8 @@ export default function Slider(props: SliderProps) {
onChange(parseInt(ev.currentTarget.value, 10)) onChange(parseInt(ev.currentTarget.value, 10))
}} }}
onClick={onClick} onClick={onClick}
style={styles}
onMouseUp={onMouseUp}
/> />
</div> </div>
) )

View File

@ -5,7 +5,7 @@ import React, {
useState, useState,
useCallback, useCallback,
} from 'react' } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState } from 'recoil'
import PhotoAlbum, { RenderPhoto } from 'react-photo-album' import PhotoAlbum, { RenderPhoto } from 'react-photo-album'
import * as ScrollArea from '@radix-ui/react-scroll-area' import * as ScrollArea from '@radix-ui/react-scroll-area'
import Modal from '../shared/Modal' import Modal from '../shared/Modal'
@ -127,7 +127,7 @@ export default function FileManager(props: Props) {
ref={onRefChange} ref={onRefChange}
> >
<PhotoAlbum <PhotoAlbum
layout="columns" layout="masonry"
photos={photos} photos={photos}
renderPhoto={renderPhoto} renderPhoto={renderPhoto}
spacing={8} spacing={8}

View File

@ -21,7 +21,7 @@ function SettingBlock(props: SettingBlockProps) {
<div className={`setting-block ${className}`}> <div className={`setting-block ${className}`}>
<div className={contentClass}> <div className={contentClass}>
<div className="setting-block-content-title"> <div className="setting-block-content-title">
<div style={{ display: 'flex', alignItems: 'center', gap: '12px' }}> <div style={{ display: 'flex', alignItems: 'center', gap: '8px' }}>
{desc ? ( {desc ? (
<Tooltip content={<div style={{ maxWidth: 400 }}>{desc}</div>}> <Tooltip content={<div style={{ maxWidth: 400 }}>{desc}</div>}>
<span>{title}</span> <span>{title}</span>

View File

@ -1,8 +1,5 @@
import React, { FormEvent } from 'react' import React from 'react'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState, useRecoilValue } from 'recoil'
import { FolderOpenIcon } from '@heroicons/react/24/outline'
import * as Tabs from '@radix-ui/react-tabs'
import { import {
isPaintByExampleState, isPaintByExampleState,
isSDState, isSDState,
@ -14,10 +11,6 @@ import HDSettingBlock from './HDSettingBlock'
import ModelSettingBlock from './ModelSettingBlock' import ModelSettingBlock from './ModelSettingBlock'
import DownloadMaskSettingBlock from './DownloadMaskSettingBlock' import DownloadMaskSettingBlock from './DownloadMaskSettingBlock'
import useHotKey from '../../hooks/useHotkey' import useHotKey from '../../hooks/useHotkey'
import SettingBlock from './SettingBlock'
import { Switch, SwitchThumb } from '../shared/Switch'
import Button from '../shared/Button'
import TextInput from '../shared/Input'
declare module 'react' { declare module 'react' {
interface InputHTMLAttributes<T> extends HTMLAttributes<T> { interface InputHTMLAttributes<T> extends HTMLAttributes<T> {

View File

@ -0,0 +1,58 @@
import React from 'react'
import { useRecoilState, useRecoilValue } from 'recoil'
import { appState, croperState, settingState } from '../../store/Atoms'
import Slider from '../Editor/Slider'
import SettingBlock from '../Settings/SettingBlock'
/**
 * Side-panel setting that edits `sdScale` through a slider and shows the
 * resulting size next to the title.
 *
 * The size is computed from the croper's dimensions when
 * `setting.showCroper` is on, otherwise from the image dimensions stored in
 * `appState`.
 */
const ImageResizeScale = () => {
  const [setting, setSettingState] = useRecoilState(settingState)
  const app = useRecoilValue(appState)
  const croper = useRecoilValue(croperState)

  // sdScale is a percentage: apply it and round to whole pixels.
  const applyScale = (length: number) =>
    Math.round((length * setting.sdScale) / 100)

  const baseWidth = setting.showCroper ? croper.width : app.imageWidth
  const baseHeight = setting.showCroper ? croper.height : app.imageHeight
  const sizeLabel = `(${applyScale(baseWidth)}x${applyScale(baseHeight)})`

  const onScaleChange = (value: number) => {
    setSettingState(old => ({ ...old, sdScale: value }))
  }

  return (
    <SettingBlock
      className="sub-setting-block"
      title="Resize"
      titleSuffix={<div style={{ width: 86 }}>{sizeLabel}</div>}
      desc="Resize the image before inpainting, the area outside the mask will not lose quality."
      input={
        <Slider
          label=""
          width={70}
          min={50}
          max={100}
          value={setting.sdScale}
          onChange={onScaleChange}
        />
      }
    />
  )
}
export default ImageResizeScale

View File

@ -14,6 +14,7 @@ import { Switch, SwitchThumb } from '../shared/Switch'
import Button from '../shared/Button' import Button from '../shared/Button'
import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event' import emitter, { EVENT_PAINT_BY_EXAMPLE } from '../../event'
import { useImage } from '../../utils' import { useImage } from '../../utils'
import ImageResizeScale from './ImageResizeScale'
const INPUT_WIDTH = 30 const INPUT_WIDTH = 30
@ -68,22 +69,6 @@ const PESidePanel = () => {
</PopoverPrimitive.Trigger> </PopoverPrimitive.Trigger>
<PopoverPrimitive.Portal> <PopoverPrimitive.Portal>
<PopoverPrimitive.Content className="side-panel-content"> <PopoverPrimitive.Content className="side-panel-content">
<SettingBlock
title="Croper"
input={
<Switch
checked={setting.showCroper}
onCheckedChange={value => {
setSettingState(old => {
return { ...old, showCroper: value }
})
}}
>
<SwitchThumb />
</Switch>
}
/>
<NumberInputSetting <NumberInputSetting
title="Steps" title="Steps"
width={INPUT_WIDTH} width={INPUT_WIDTH}
@ -97,6 +82,8 @@ const PESidePanel = () => {
}} }}
/> />
<ImageResizeScale />
<NumberInputSetting <NumberInputSetting
title="Guidance Scale" title="Guidance Scale"
width={INPUT_WIDTH} width={INPUT_WIDTH}

View File

@ -1,4 +1,4 @@
import React, { FormEvent, useState } from 'react' import React, { FormEvent } from 'react'
import { useRecoilState, useRecoilValue } from 'recoil' import { useRecoilState, useRecoilValue } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover' import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use' import { useToggle } from 'react-use'
@ -15,6 +15,7 @@ import Selector from '../shared/Selector'
import { Switch, SwitchThumb } from '../shared/Switch' import { Switch, SwitchThumb } from '../shared/Switch'
import TextAreaInput from '../shared/Textarea' import TextAreaInput from '../shared/Textarea'
import emitter, { EVENT_PROMPT } from '../../event' import emitter, { EVENT_PROMPT } from '../../event'
import ImageResizeScale from './ImageResizeScale'
const INPUT_WIDTH = 30 const INPUT_WIDTH = 30
@ -71,6 +72,9 @@ const SidePanel = () => {
</Switch> </Switch>
} }
/> />
<ImageResizeScale />
{/* {/*
<NumberInputSetting <NumberInputSetting
title="Num Samples" title="Num Samples"
@ -98,21 +102,6 @@ const SidePanel = () => {
}} }}
/> />
{/* <NumberInputSetting
title="Strength"
width={INPUT_WIDTH}
allowFloat
value={`${setting.sdStrength}`}
desc="TODO"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
console.log(val)
setSettingState(old => {
return { ...old, sdStrength: val }
})
}}
/> */}
<NumberInputSetting <NumberInputSetting
title="Guidance Scale" title="Guidance Scale"
width={INPUT_WIDTH} width={INPUT_WIDTH}

View File

@ -35,6 +35,8 @@ export interface Rect {
interface AppState { interface AppState {
file: File | undefined file: File | undefined
imageHeight: number
imageWidth: number
disableShortCuts: boolean disableShortCuts: boolean
isInpainting: boolean isInpainting: boolean
isDisableModelSwitch: boolean isDisableModelSwitch: boolean
@ -49,6 +51,8 @@ export const appState = atom<AppState>({
key: 'appState', key: 'appState',
default: { default: {
file: undefined, file: undefined,
imageHeight: 0,
imageWidth: 0,
disableShortCuts: false, disableShortCuts: false,
isInpainting: false, isInpainting: false,
isDisableModelSwitch: false, isDisableModelSwitch: false,
@ -82,6 +86,30 @@ export const isInpaintingState = selector({
}, },
}) })
/** Read/write view of the `imageHeight` field stored inside `appState`. */
export const imageHeightState = selector({
  key: 'imageHeightState',
  get: ({ get }) => get(appState).imageHeight,
  set: ({ get, set }, newValue: any) => {
    // Copy the current app state and replace only the height.
    const current = get(appState)
    set(appState, { ...current, imageHeight: newValue })
  },
})
/** Read/write view of the `imageWidth` field stored inside `appState`. */
export const imageWidthState = selector({
  key: 'imageWidthState',
  get: ({ get }) => get(appState).imageWidth,
  set: ({ get, set }, newValue: any) => {
    // Copy the current app state and replace only the width.
    const current = get(appState)
    set(appState, { ...current, imageWidth: newValue })
  },
})
export const showFileManagerState = selector({ export const showFileManagerState = selector({
key: 'showFileManager', key: 'showFileManager',
get: ({ get }) => { get: ({ get }) => {
@ -121,6 +149,12 @@ export const fileState = selector({
isInteractiveSeg: false, isInteractiveSeg: false,
isInteractiveSegRunning: false, isInteractiveSegRunning: false,
}) })
const setting = get(settingState)
set(settingState, {
...setting,
sdScale: 100,
})
}, },
}) })
@ -282,6 +316,7 @@ export interface Settings {
sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend
sdNumSamples: number sdNumSamples: number
sdMatchHistograms: boolean sdMatchHistograms: boolean
sdScale: number
// For OpenCV2 // For OpenCV2
cv2Radius: number cv2Radius: number
@ -409,6 +444,7 @@ export const settingStateDefault: Settings = {
sdSeedFixed: true, sdSeedFixed: true,
sdNumSamples: 1, sdNumSamples: 1,
sdMatchHistograms: false, sdMatchHistograms: false,
sdScale: 100,
// CV2 // CV2
cv2Radius: 5, cv2Radius: 5,

View File

@ -1,11 +1,16 @@
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py # Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
import os import os
from functools import lru_cache from cachetools import TTLCache, cached
import cv2
import time
from io import BytesIO from io import BytesIO
from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, PngImagePlugin from PIL import Image, ImageOps, PngImagePlugin
LARGE_ENOUGH_NUMBER = 100 LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2)
from .storage_backends import FilesystemStorageBackend from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img from .utils import aspect_to_string, generate_filename, glob_img
@ -18,6 +23,7 @@ class FileManager:
self._default_root_url = "/" self._default_root_url = "/"
self._default_thumbnail_root_url = "/" self._default_thumbnail_root_url = "/"
self._default_format = "JPEG" self._default_format = "JPEG"
self.output_dir: Path = None
if app is not None: if app is not None:
self.init_app(app) self.init_app(app)
@ -41,6 +47,16 @@ class FileManager:
app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url) app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url)
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format) app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
def save_to_output_directory(self, image: np.ndarray, filename: str):
    """Write ``image`` into ``self.output_dir`` under a timestamped name.

    The stored name is ``<stem>_<unix seconds><suffix>``, built from the
    final path component of ``filename``.

    Args:
        image: Pixel array. Arrays with 3 or 4 channels are converted
            RGB->BGR / RGBA->BGRA before writing (assumes the caller passes
            RGB(A) order -- confirm against the caller). Arrays without a
            channel axis are written unchanged.
        filename: Original file name; only its stem and suffix are used.

    Raises:
        RuntimeError: If no output directory is configured, or OpenCV
            reports that the file could not be written.
    """
    # output_dir defaults to None until the server configures it.
    if self.output_dir is None:
        raise RuntimeError("output directory is not configured")

    fp = Path(filename)
    new_name = fp.stem + f"_{int(time.time())}" + fp.suffix

    # Grayscale arrays have no channel axis; only inspect it when present.
    channels = image.shape[2] if image.ndim == 3 else 1
    if channels == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif channels == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)

    target = self.output_dir / new_name
    # cv2.imwrite signals failure through its return value, so check it
    # instead of silently reporting success.
    if not cv2.imwrite(str(target), image):
        raise RuntimeError(f"failed to write image to {target}")
@property @property
def root_directory(self): def root_directory(self):
path = self.app.config["THUMBNAIL_MEDIA_ROOT"] path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
@ -64,7 +80,7 @@ class FileManager:
return self.app.config["THUMBNAIL_MEDIA_URL"] return self.app.config["THUMBNAIL_MEDIA_URL"]
@property @property
@lru_cache() @cached(cache=TTLCache(maxsize=1024, ttl=30))
def media_names(self): def media_names(self):
names = sorted([it.name for it in glob_img(self.root_directory)]) names = sorted([it.name for it in glob_img(self.root_directory)])
res = [] res = []

View File

@ -6,6 +6,9 @@ import cv2
import numpy as np import numpy as np
import torch import torch
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from loguru import logger
from lama_cleaner.helper import resize_max_size
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
@ -15,15 +18,20 @@ class PaintByExample(InpaintModel):
min_size = 512 min_size = 512
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs['no_half'] fp16 = not kwargs.get('no_half', False)
use_gpu = device == torch.device('cuda') and torch.cuda.is_available() use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
self.model = DiffusionPipeline.from_pretrained( self.model = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example", "Fantasy-Studio/Paint-by-Example",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
**model_kwargs
) )
self.model.enable_attention_slicing()
self.model = self.model.to(device) self.model = self.model.to(device)
self.model.enable_attention_slicing()
# TODO: gpu_id
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
self.model.enable_sequential_cpu_offload(gpu_id=0)
def forward(self, image, mask, config: Config): def forward(self, image, mask, config: Config):
"""Input image and output image have same size """Input image and output image have same size
@ -49,6 +57,25 @@ class PaintByExample(InpaintModel):
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output return output
def _scaled_pad_forward(self, image, mask, config: Config):
    """Inpaint a resized copy, then restore original pixels outside the mask.

    The image and mask are resized so the longer image side is
    ``config.sd_scale`` times its original length, run through
    ``_pad_forward``, and the result is resized back to the input size.
    Positions where ``mask < 127`` are overwritten with the input image with
    its last axis reversed.
    """
    origin_size = image.shape[:2]
    longer_side_length = int(config.sd_scale * max(origin_size))
    downsize_image = resize_max_size(image, size_limit=longer_side_length)
    downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
    logger.info(
        f"Resize image to do paint_by_example: {image.shape} -> {downsize_image.shape}"
    )

    height, width = origin_size
    # cv2.resize takes the target size as (width, height).
    restored = cv2.resize(
        self._pad_forward(downsize_image, downsize_mask, config),
        (width, height),
        interpolation=cv2.INTER_CUBIC,
    )

    # only paste masked area result: everything else keeps the input pixels
    keep_original = mask < 127
    restored[keep_original] = image[:, :, ::-1][keep_original]
    return restored
@torch.no_grad() @torch.no_grad()
def __call__(self, image, mask, config: Config): def __call__(self, image, mask, config: Config):
""" """
@ -58,11 +85,11 @@ class PaintByExample(InpaintModel):
""" """
if config.use_croper: if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config) crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._pad_forward(crop_img, crop_mask, config) crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1] inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image inpaint_result[t:b, l:r, :] = crop_image
else: else:
inpaint_result = self._pad_forward(image, mask, config) inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result return inpaint_result

View File

@ -8,7 +8,9 @@ from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerD
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
from loguru import logger from loguru import logger
from lama_cleaner.helper import resize_max_size
from lama_cleaner.model.base import InpaintModel from lama_cleaner.model.base import InpaintModel
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.schema import Config, SDSampler from lama_cleaner.schema import Config, SDSampler
@ -18,6 +20,8 @@ class CPUTextEncoderWrapper:
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True) self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype self.torch_dtype = torch_dtype
del text_encoder
torch_gc()
def __call__(self, x, **kwargs): def __call__(self, x, **kwargs):
input_device = x.device input_device = x.device
@ -30,9 +34,9 @@ class SD(InpaintModel):
def init_model(self, device: torch.device, **kwargs): def init_model(self, device: torch.device, **kwargs):
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
fp16 = not kwargs['no_half'] fp16 = not kwargs.get('no_half', False)
model_kwargs = {"local_files_only": kwargs['sd_run_local']} model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
if kwargs['sd_disable_nsfw']: if kwargs['sd_disable_nsfw']:
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict( model_kwargs.update(dict(
@ -48,19 +52,43 @@ class SD(InpaintModel):
use_auth_token=kwargs["hf_access_token"], use_auth_token=kwargs["hf_access_token"],
**model_kwargs **model_kwargs
) )
self.model = self.model.to(device)
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing() self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get('sd_enable_xformers', False): if kwargs.get('sd_enable_xformers', False):
self.model.enable_xformers_memory_efficient_attention() self.model.enable_xformers_memory_efficient_attention()
self.model = self.model.to(device)
if kwargs['sd_cpu_textencoder']: if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
logger.info("Run Stable Diffusion TextEncoder on CPU") # TODO: gpu_id
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype) self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
if kwargs['sd_cpu_textencoder']:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
self.callback = kwargs.pop("callback", None) self.callback = kwargs.pop("callback", None)
def _scaled_pad_forward(self, image, mask, config: Config):
    """Inpaint a resized copy, then restore original pixels outside the mask.

    The image and mask are resized so the longer image side is
    ``config.sd_scale`` times its original length, run through
    ``_pad_forward``, and the result is resized back to the input size.
    Positions where ``mask < 127`` are overwritten with the input image with
    its last axis reversed.
    """
    origin_size = image.shape[:2]
    longer_side_length = int(config.sd_scale * max(origin_size))
    downsize_image = resize_max_size(image, size_limit=longer_side_length)
    downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
    logger.info(
        f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
    )

    height, width = origin_size
    # cv2.resize takes the target size as (width, height).
    restored = cv2.resize(
        self._pad_forward(downsize_image, downsize_mask, config),
        (width, height),
        interpolation=cv2.INTER_CUBIC,
    )

    # only paste masked area result: everything else keeps the input pixels
    keep_original = mask < 127
    restored[keep_original] = image[:, :, ::-1][keep_original]
    return restored
def forward(self, image, mask, config: Config): def forward(self, image, mask, config: Config):
"""Input image and output image have same size """Input image and output image have same size
image: [H, W, C] RGB image: [H, W, C] RGB
@ -126,11 +154,11 @@ class SD(InpaintModel):
# boxes = boxes_from_mask(mask) # boxes = boxes_from_mask(mask)
if config.use_croper: if config.use_croper:
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config) crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
crop_image = self._pad_forward(crop_img, crop_mask, config) crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
inpaint_result = image[:, :, ::-1] inpaint_result = image[:, :, ::-1]
inpaint_result[t:b, l:r, :] = crop_image inpaint_result[t:b, l:r, :] = crop_image
else: else:
inpaint_result = self._pad_forward(image, mask, config) inpaint_result = self._scaled_pad_forward(image, mask, config)
return inpaint_result return inpaint_result

View File

@ -707,3 +707,9 @@ class Conv2dLayer(torch.nn.Module):
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
return out return out
def torch_gc():
    """Release cached CUDA memory; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

View File

@ -15,7 +15,9 @@ def parse_args():
default="lama", default="lama",
choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"], choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga", "sd2", "paint_by_example"],
) )
parser.add_argument("--no-half", action="store_true", help="SD/PaintByExample model no half precision") parser.add_argument("--no-half", action="store_true", help="sd/paint_by_example model no half precision")
parser.add_argument("--cpu-offload", action="store_true",
help="sd/paint_by_example model, offloads all models to CPU, significantly reducing vRAM usage.")
parser.add_argument( parser.add_argument(
"--hf_access_token", "--hf_access_token",
default="", default="",
@ -34,7 +36,12 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--sd-run-local", "--sd-run-local",
action="store_true", action="store_true",
help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token", help="After first time Stable Diffusion model downloaded, you can add this arg and remove --hf_access_token.",
)
parser.add_argument(
"--local-files-only",
action="store_true",
help="sd/paint_by_example model. Use local files only, not connect to huggingface server",
) )
parser.add_argument( parser.add_argument(
"--sd-enable-xformers", "--sd-enable-xformers",
@ -80,7 +87,7 @@ def parse_args():
if not output_dir.is_dir(): if not output_dir.is_dir():
parser.error(f"invalid --output-dir: {output_dir} is not a directory") parser.error(f"invalid --output-dir: {output_dir} is not a directory")
if args.model == 'sd1.5' and not args.sd_run_local: if args.model == 'sd1.5' and not (args.sd_run_local or args.local_files_only):
if not args.hf_access_token.startswith("hf_"): if not args.hf_access_token.startswith("hf_"):
parser.error( parser.error(
f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens" f"sd(stable-diffusion) model requires huggingface access token. Check how to get token from: https://huggingface.co/docs/hub/security-tokens"

View File

@ -58,6 +58,9 @@ class Config(BaseModel):
croper_height: int = None croper_height: int = None
croper_width: int = None croper_width: int = None
# Resize the image before doing sd inpainting, the area outside the mask will not lose quality.
# Used by sd models and paint_by_example model
sd_scale: float = 1.0
# Blur the edge of mask area. The higher the number the smoother blend with the original image # Blur the edge of mask area. The higher the number the smoother blend with the original image
sd_mask_blur: int = 0 sd_mask_blur: int = 0
# Ignore this value, it's useless for inpainting # Ignore this value, it's useless for inpainting

View File

@ -31,7 +31,6 @@ except:
pass pass
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
from flask_caching import Cache
# Disable ability for Flask to display warning about using a development server in a production environment. # Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356 # https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
@ -67,12 +66,9 @@ class NoFlaskwebgui(logging.Filter):
logging.getLogger("werkzeug").addFilter(NoFlaskwebgui()) logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
cache = Cache(config={'CACHE_TYPE': 'SimpleCache'}, with_jinja2_ext=False)
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static")) app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition"]) CORS(app, expose_headers=["Content-Disposition"])
cache.init_app(app)
model: ModelManager = None model: ModelManager = None
thumb = FileManager(app) thumb = FileManager(app)
@ -96,6 +92,16 @@ def diffuser_callback(i, t, latents):
# socketio.emit('diffusion_step', {'diffusion_step': step}) # socketio.emit('diffusion_step', {'diffusion_step': step})
@app.route("/save_image", methods=["POST"])
def save_image():
    """Store the uploaded ``image`` file part via the file manager.

    Reads the ``image`` part of the multipart request, decodes it with
    ``load_img`` and hands it to ``thumb.save_to_output_directory`` together
    with the ``filename`` form field.
    """
    # all image in output directory
    uploaded = request.files["image"]
    image, _ = load_img(uploaded.read())  # RGB
    thumb.save_to_output_directory(image, request.form["filename"])
    return 'ok', 200
@app.route("/medias") @app.route("/medias")
def medias(): def medias():
# all images in input folder # all images in input folder
@ -172,6 +178,7 @@ def process():
croper_y=form["croperY"], croper_y=form["croperY"],
croper_height=form["croperHeight"], croper_height=form["croperHeight"],
croper_width=form["croperWidth"], croper_width=form["croperWidth"],
sd_scale=form["sdScale"],
sd_mask_blur=form["sdMaskBlur"], sd_mask_blur=form["sdMaskBlur"],
sd_strength=form["sdStrength"], sd_strength=form["sdStrength"],
sd_steps=form["sdSteps"], sd_steps=form["sdSteps"],
@ -345,6 +352,7 @@ def main(args):
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'thumbnails') app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'thumbnails')
is_enable_file_manager = True is_enable_file_manager = True
thumb.output_dir = Path(args.output_dir)
else: else:
input_image_path = args.input input_image_path = args.input
@ -356,6 +364,8 @@ def main(args):
sd_disable_nsfw=args.sd_disable_nsfw, sd_disable_nsfw=args.sd_disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder, sd_cpu_textencoder=args.sd_cpu_textencoder,
sd_run_local=args.sd_run_local, sd_run_local=args.sd_run_local,
local_files_only=args.local_files_only,
cpu_offload=args.cpu_offload,
sd_enable_xformers=args.sd_enable_xformers, sd_enable_xformers=args.sd_enable_xformers,
callback=diffuser_callback, callback=diffuser_callback,
) )

View File

@ -38,7 +38,7 @@ def assert_equal(
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example(strategy): def test_paint_by_example(strategy):
model = ModelManager(name="paint_by_example", device=device) model = ModelManager(name="paint_by_example", device=device)
cfg = get_config(strategy, paint_by_example_steps=30 if device == 'cuda' else 1) cfg = get_config(strategy, paint_by_example_steps=30)
assert_equal( assert_equal(
model, model,
cfg, cfg,
@ -46,5 +46,35 @@ def test_paint_by_example(strategy):
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fy=0.9, fy=0.9,
fx=1.3,
)
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_sd_scale(strategy):
    """Run the paint_by_example model through ``assert_equal`` with ``sd_scale=0.85``."""
    manager = ModelManager(name="paint_by_example", device=device)
    config = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)

    # Name passed to assert_equal for this configuration.
    result_name = f"paint_by_example_{strategy.capitalize()}_sdscale.png"
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    assert_equal(manager, config, result_name, img_p=image_path, mask_p=mask_path, fy=0.9, fx=1.3)
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_cpu_offload(strategy):
    """Run the paint_by_example model through ``assert_equal`` with ``cpu_offload=True``."""
    manager = ModelManager(name="paint_by_example", device=device, cpu_offload=True)
    config = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)

    # Name passed to assert_equal for this configuration.
    result_name = f"paint_by_example_{strategy.capitalize()}_cpu_offload.png"
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    assert_equal(manager, config, result_name, img_p=image_path, mask_p=mask_path, fy=0.9, fx=1.3)

View File

@ -99,8 +99,8 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=True, sd_disable_nsfw=False,
sd_cpu_textencoder=True, sd_cpu_textencoder=False,
callback=callback) callback=callback)
cfg = get_config( cfg = get_config(
strategy, strategy,
@ -121,3 +121,63 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=1 fx=1
) )
@pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
@pytest.mark.parametrize("cpu_textencoder", [False])
@pytest.mark.parametrize("disable_nsfw", [False])
def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
    """Run the sd1.5 model through ``assert_equal`` with ``sd_scale=0.85``.

    Skipped when the parametrized device is CUDA and no CUDA device is
    available.
    """
    if sd_device == 'cuda' and not torch.cuda.is_available():
        # A bare ``return`` would report this test as passed even though
        # nothing ran; skip so the report reflects that.
        pytest.skip("CUDA is not available")

    # Use the full step count on CUDA and a single step on any other device.
    sd_steps = 50 if sd_device == 'cuda' else 1
    model = ModelManager(name="sd1.5",
                         device=torch.device(sd_device),
                         hf_access_token="",
                         sd_run_local=True,
                         sd_disable_nsfw=disable_nsfw,
                         sd_cpu_textencoder=cpu_textencoder)
    cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
    cfg.sd_sampler = sampler

    # Encode the parametrization in the file name passed to assert_equal.
    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"

    assert_equal(
        model,
        cfg,
        f"runway_sd_{strategy.capitalize()}_{name}_sdscale.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=1.3
    )
@pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
    """Run the sd1.5 model through ``assert_equal`` with ``cpu_offload=True``.

    Skipped when the parametrized device is CUDA and no CUDA device is
    available.
    """
    if sd_device == 'cuda' and not torch.cuda.is_available():
        # A bare ``return`` would report this test as passed even though
        # nothing ran; skip so the report reflects that.
        pytest.skip("CUDA is not available")

    # Use the full step count on CUDA and a single step on any other device.
    sd_steps = 50 if sd_device == 'cuda' else 1
    model = ModelManager(name="sd1.5",
                         device=torch.device(sd_device),
                         hf_access_token="",
                         sd_run_local=True,
                         sd_disable_nsfw=False,
                         sd_cpu_textencoder=False,
                         cpu_offload=True)
    cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
    cfg.sd_sampler = sampler

    # Encode the parametrization in the file name passed to assert_equal.
    name = f"device_{sd_device}_{sampler}"

    assert_equal(
        model,
        cfg,
        f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )

View File

@ -12,3 +12,4 @@ markupsafe==2.0.1
scikit-image==0.19.3 scikit-image==0.19.3
diffusers[torch]==0.10.2 diffusers[torch]==0.10.2
transformers>=4.25.1 transformers>=4.25.1
cachetools==5.2.0