From f38be37f8c32069c08d3cb73436e9e58b0d86997 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Jan 2024 14:34:36 +0800 Subject: [PATCH] get samplers from backend --- lama_cleaner/api.py | 6 ++ lama_cleaner/download.py | 5 +- lama_cleaner/model/kandinsky.py | 4 +- lama_cleaner/model/utils.py | 59 +++++++++++------ lama_cleaner/schema.py | 30 ++++++--- lama_cleaner/tests/test_outpainting.py | 6 +- lama_cleaner/tests/test_sd_model.py | 63 ++++++++++--------- lama_cleaner/tests/utils.py | 8 ++- requirements.txt | 14 ++--- scripts/tool.py | 14 ++--- .../components/SidePanel/DiffusionOptions.tsx | 10 +-- web_app/src/lib/api.ts | 5 ++ web_app/src/lib/states.ts | 6 +- web_app/src/lib/types.ts | 12 +--- 14 files changed, 141 insertions(+), 101 deletions(-) diff --git a/lama_cleaner/api.py b/lama_cleaner/api.py index 4f326a1..a8fee6d 100644 --- a/lama_cleaner/api.py +++ b/lama_cleaner/api.py @@ -37,6 +37,7 @@ from lama_cleaner.schema import ( SwitchModelRequest, InpaintRequest, RunPluginRequest, + SDSampler, ) from lama_cleaner.file_manager import FileManager @@ -129,6 +130,7 @@ class Api: self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"]) self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"]) self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"]) + self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"]) self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets") # fmt: on @@ -156,6 +158,7 @@ class Api: controlnetMethod=self.model_manager.controlnet_method, disableModelSwitch=self.config.disable_model_switch, isDesktop=self.config.gui, + samplers=self.api_samplers(), ) def api_input_image(self) -> FileResponse: @@ -237,6 +240,9 @@ class Api: media_type=f"image/{ext}", ) + def api_samplers(self) -> List[str]: + return [member.value for member in SDSampler.__members__.values()] + def launch(self): self.app.include_router(self.router) uvicorn.run( diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py index 6e90f4c..ca3781d 100644 --- a/lama_cleaner/download.py +++ b/lama_cleaner/download.py @@ -2,6 +2,7 @@ import json import os from typing import List +from huggingface_hub.constants import HF_HUB_CACHE from loguru import logger from pathlib import Path @@ -101,13 +102,11 @@ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]: def scan_models() -> List[ModelInfo]: - from diffusers.utils import DIFFUSERS_CACHE - model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR) available_models = [] available_models.extend(scan_inpaint_models(model_dir)) available_models.extend(scan_single_file_diffusion_models(model_dir)) - cache_dir = Path(DIFFUSERS_CACHE) + cache_dir = Path(HF_HUB_CACHE) # logger.info(f"Scanning diffusers models in {cache_dir}") diffusers_model_names = [] for it in cache_dir.glob("**/*/model_index.json"): diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py index 965aa72..e5467e0 100644 --- a/lama_cleaner/model/kandinsky.py +++ b/lama_cleaner/model/kandinsky.py @@ -35,9 +35,7 @@ class Kandinsky(DiffusionInpaintModel): mask: [H, W, 1] 255 means area to repaint return: BGR IMAGE """ - scheduler_config = self.model.scheduler.config - scheduler = get_scheduler(config.sd_sampler, scheduler_config) - self.model.scheduler = scheduler + self.set_scheduler(config) generator = torch.manual_seed(config.sd_seed) mask = mask.astype(np.float32) / 255 diff --git a/lama_cleaner/model/utils.py b/lama_cleaner/model/utils.py index 18fcf6c..535bacf 100644 --- a/lama_cleaner/model/utils.py +++ b/lama_cleaner/model/utils.py @@ -1,3 +1,4 @@ +import copy import gc import math import random @@ -18,10 +19,13 @@ from diffusers import ( DPMSolverMultistepScheduler, UniPCMultistepScheduler, LCMScheduler, + DPMSolverSinglestepScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, ) -from huggingface_hub.utils import RevisionNotFoundError +from diffusers.configuration_utils import FrozenDict from loguru import logger -from requests import HTTPError from lama_cleaner.schema import SDSampler from torch import conv2d, conv_transpose2d @@ -930,22 +934,41 @@ def set_seed(seed: int): def get_scheduler(sd_sampler, scheduler_config): - if sd_sampler == SDSampler.ddim: - return DDIMScheduler.from_config(scheduler_config) - elif sd_sampler == SDSampler.pndm: - return PNDMScheduler.from_config(scheduler_config) - elif sd_sampler == SDSampler.k_lms: - return LMSDiscreteScheduler.from_config(scheduler_config) - elif sd_sampler == SDSampler.k_euler: - return EulerDiscreteScheduler.from_config(scheduler_config) - elif sd_sampler == SDSampler.k_euler_a: - return EulerAncestralDiscreteScheduler.from_config(scheduler_config) - elif sd_sampler == SDSampler.dpm_plus_plus: - return DPMSolverMultistepScheduler.from_config(scheduler_config) - elif sd_sampler == SDSampler.uni_pc: - return UniPCMultistepScheduler.from_config(scheduler_config) - elif sd_sampler == SDSampler.lcm: - return LCMScheduler.from_config(scheduler_config) + # https://github.com/huggingface/diffusers/issues/4167 + keys_to_pop = ["use_karras_sigmas", "algorithm_type"] + scheduler_config = dict(scheduler_config) + for it in keys_to_pop: + scheduler_config.pop(it, None) + + # fmt: off + samplers = { + SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler], + SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)], + SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")], + SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)], + SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler], + SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)], + SDSampler.dpm2: [KDPM2DiscreteScheduler], + SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)], + SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler], + SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)], + SDSampler.euler: [EulerDiscreteScheduler], + SDSampler.euler_a: [EulerAncestralDiscreteScheduler], + SDSampler.heun: [HeunDiscreteScheduler], + SDSampler.lms: [LMSDiscreteScheduler], + SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)], + SDSampler.ddim: [DDIMScheduler], + SDSampler.pndm: [PNDMScheduler], + SDSampler.uni_pc: [UniPCMultistepScheduler], + SDSampler.lcm: [LCMScheduler], + } + # fmt: on + if sd_sampler in samplers: + if len(samplers[sd_sampler]) == 2: + scheduler_cls, kwargs = samplers[sd_sampler] + else: + scheduler_cls, kwargs = samplers[sd_sampler][0], {} + return scheduler_cls.from_config(scheduler_config, **kwargs) else: raise ValueError(sd_sampler) diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index db67f53..c347796 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -40,15 +40,26 @@ class LDMSampler(str, Enum): class SDSampler(str, Enum): - ddim = "ddim" - pndm = "pndm" - k_lms = "k_lms" - k_euler = "k_euler" - k_euler_a = "k_euler_a" - dpm_plus_plus = "dpm++" - uni_pc = "uni_pc" + dpm_plus_plus_2m = "DPM++ 2M" + dpm_plus_plus_2m_karras = "DPM++ 2M Karras" + dpm_plus_plus_2m_sde = "DPM++ 2M SDE" + dpm_plus_plus_2m_sde_karras = "DPM++ 2M SDE Karras" + dpm_plus_plus_sde = "DPM++ SDE" + dpm_plus_plus_sde_karras = "DPM++ SDE Karras" + dpm2 = "DPM2" + dpm2_karras = "DPM2 Karras" + dpm2_a = "DPM2 a" + dpm2_a_karras = "DPM2 a Karras" + euler = "Euler" + euler_a = "Euler a" + heun = "Heun" + lms = "LMS" + lms_karras = "LMS Karras" - lcm = "lcm" + ddim = "DDIM" + pndm = "PNDM" + uni_pc = "UniPC" + lcm = "LCM" class FREEUConfig(BaseModel): @@ -143,7 +154,7 @@ class InpaintRequest(BaseModel): le=1.0, ) sd_mask_blur: int = Field( - 33, + 11, description="Blur the edge of mask area. The higher the number the smoother blend with the original image", ) sd_strength: float = Field( @@ -268,6 +279,7 @@ class ServerConfigResponse(BaseModel): controlnetMethod: Optional[str] disableModelSwitch: bool isDesktop: bool + samplers: List[str] class SwitchModelRequest(BaseModel): diff --git a/lama_cleaner/tests/test_outpainting.py b/lama_cleaner/tests/test_outpainting.py index c036e85..b47e1ee 100644 --- a/lama_cleaner/tests/test_outpainting.py +++ b/lama_cleaner/tests/test_outpainting.py @@ -49,7 +49,7 @@ def test_outpainting(name, device, rect): extender_width=rect[2], extender_height=rect[3], sd_guidance_scale=8.0, - sd_sampler=SDSampler.dpm_plus_plus, + sd_sampler=SDSampler.dpm_plus_plus_2m, ) assert_equal( @@ -92,7 +92,7 @@ def test_kandinsky_outpainting(name, device, rect): extender_width=rect[2], extender_height=rect[3], sd_guidance_scale=7, - sd_sampler=SDSampler.dpm_plus_plus, + sd_sampler=SDSampler.dpm_plus_plus_2m, ) assert_equal( @@ -136,7 +136,7 @@ def test_powerpaint_outpainting(name, device, rect): extender_width=rect[2], extender_height=rect[3], sd_guidance_scale=8.0, - sd_sampler=SDSampler.dpm_plus_plus, + sd_sampler=SDSampler.dpm_plus_plus_2m, powerpaint_task="outpainting", ) diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py index 90fdc29..b94a3ef 100644 --- a/lama_cleaner/tests/test_sd_model.py +++ b/lama_cleaner/tests/test_sd_model.py @@ -1,5 +1,7 @@ import os +from loguru import logger + from lama_cleaner.tests.utils import check_device, get_config, assert_equal os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -17,21 +19,7 @@ save_dir.mkdir(exist_ok=True, parents=True) @pytest.mark.parametrize("device", ["cuda", "mps"]) -@pytest.mark.parametrize( - "sampler", - [ - SDSampler.ddim, - SDSampler.pndm, - SDSampler.k_lms, - SDSampler.k_euler, - SDSampler.k_euler_a, - SDSampler.lcm, - ], -) -def test_runway_sd_1_5_all_samplers( - device, - sampler, -): +def test_runway_sd_1_5_all_samplers(device): sd_steps = check_device(device) model = ModelManager( name="runwayml/stable-diffusion-inpainting", @@ -39,22 +27,37 @@ def test_runway_sd_1_5_all_samplers( disable_nsfw=True, sd_cpu_textencoder=False, ) - cfg = get_config( - strategy=HDStrategy.ORIGINAL, - prompt="a fox sitting on a bench", - sd_steps=sd_steps, - ) - cfg.sd_sampler = sampler - name = f"device_{device}_{sampler}" + all_samplers = [member.value for member in SDSampler.__members__.values()] + print(all_samplers) + for sampler in all_samplers: + print(f"Testing sampler {sampler}") + if ( + sampler + in [SDSampler.dpm2_karras, SDSampler.dpm2_a_karras, SDSampler.lms_karras] + and device == "mps" + ): + # diffusers 0.25.0 still has bug on these sampler on mps, wait main branch released to fix it + logger.warning( + "skip dpm2_karras on mps, diffusers does not support it on mps. TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead." + ) + continue + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + sd_sampler=sampler, + ) - assert_equal( - model, - cfg, - f"runway_sd_{name}.png", - img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", - mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", - ) + name = f"device_{device}_{sampler}" + + assert_equal( + model, + cfg, + f"runway_sd_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) @@ -171,7 +174,7 @@ def test_runway_norm_sd_model(device, strategy, sampler): @pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) -@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) +@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m]) def test_runway_sd_1_5_cpu_offload(device, strategy, sampler): sd_steps = check_device(device) model = ModelManager( diff --git a/lama_cleaner/tests/utils.py b/lama_cleaner/tests/utils.py index 00ebdf1..6786080 100644 --- a/lama_cleaner/tests/utils.py +++ b/lama_cleaner/tests/utils.py @@ -3,7 +3,9 @@ import cv2 import pytest import torch +from lama_cleaner.helper import encode_pil_to_base64 from lama_cleaner.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler +from PIL import Image current_dir = Path(__file__).parent.absolute().resolve() save_dir = current_dir / "result" @@ -21,7 +23,7 @@ def check_device(device: str) -> int: def assert_equal( model, - config, + config: InpaintRequest, gt_name, fx: float = 1, fy: float = 1, @@ -29,6 +31,8 @@ def assert_equal( mask_p=current_dir / "mask.png", ): img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) + config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0] + config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0] print(f"Input image shape: {img.shape}") res = model(img, mask, config) ok = cv2.imwrite( @@ -72,4 +76,4 @@ def get_config(**kwargs): hd_strategy_resize_limit=200, ) data.update(**kwargs) - return InpaintRequest(**data) + return InpaintRequest(image="", mask="", **data) diff --git a/requirements.txt b/requirements.txt index 461b2e3..0c4d536 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,18 @@ torch>=2.0.0 -typer opencv-python +diffusers==0.25.0 +transformers==4.34.1 +safetensors +controlnet-aux==0.0.3 fastapi==0.108.0 python-multipart simple-websocket -flask_cors flaskwebgui==0.3.5 +typer pydantic rich loguru yacs -diffusers==0.23.0 -transformers==4.34.1 gradio piexif==1.1.3 -safetensors -omegaconf -controlnet-aux==0.0.3 -werkzeug==2.2.2 +omegaconf \ No newline at end of file diff --git a/scripts/tool.py b/scripts/tool.py index 64a2c07..5b2018b 100644 --- a/scripts/tool.py +++ b/scripts/tool.py @@ -5,7 +5,7 @@ from typing import Dict, List, Union import torch from diffusers.utils import is_safetensors_available - +from huggingface_hub.constants import HF_HUB_CACHE if is_safetensors_available(): import safetensors.torch @@ -16,7 +16,6 @@ from diffusers import DiffusionPipeline, __version__ from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import ( CONFIG_NAME, - DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, ) @@ -96,7 +95,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): """ # Default kwargs from DiffusionPipeline - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE) resume_download = kwargs.pop("resume_download", False) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -246,7 +245,6 @@ class CheckpointMergerPipeline(DiffusionPipeline): print(f"Skipping {attr}: not present in 2nd or 3d model") continue - try: module = getattr(final_pipe, attr) if isinstance( @@ -267,7 +265,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): else torch.load(checkpoint_path_1, map_location="cpu") ) - if attr in ['vae', 'text_encoder']: + if attr in ["vae", "text_encoder"]: print(f"Direct use theta1 {attr}: {checkpoint_path_1}") update_theta_0(theta_1) del theta_1 @@ -348,7 +346,7 @@ pipe = CheckpointMergerPipeline.from_pretrained("runwayml/stable-diffusion-inpai merged_pipe = pipe.merge( [ "runwayml/stable-diffusion-inpainting", - #"SG161222/Realistic_Vision_V1.4", + # "SG161222/Realistic_Vision_V1.4", "dreamlike-art/dreamlike-diffusion-1.0", "runwayml/stable-diffusion-v1-5", ], @@ -358,4 +356,6 @@ merged_pipe = pipe.merge( ) merged_pipe = merged_pipe.to(torch.float16) -merged_pipe.save_pretrained("dreamlike-diffusion-1.0-inpainting", safe_serialization=True) +merged_pipe.save_pretrained( + "dreamlike-diffusion-1.0-inpainting", safe_serialization=True +) diff --git a/web_app/src/components/SidePanel/DiffusionOptions.tsx b/web_app/src/components/SidePanel/DiffusionOptions.tsx index a2c9d1d..557f4ef 100644 --- a/web_app/src/components/SidePanel/DiffusionOptions.tsx +++ b/web_app/src/components/SidePanel/DiffusionOptions.tsx @@ -11,7 +11,7 @@ import { SelectValue, } from "../ui/select" import { Textarea } from "../ui/textarea" -import { ExtenderDirection, PowerPaintTask, SDSampler } from "@/lib/types" +import { ExtenderDirection, PowerPaintTask } from "@/lib/types" import { Separator } from "../ui/separator" import { Button, ImageUploadButton } from "../ui/button" import { Slider } from "../ui/slider" @@ -42,6 +42,7 @@ const ExtenderButton = ({ const DiffusionOptions = () => { const [ + samplers, settings, paintByExampleFile, isProcessing, @@ -51,6 +52,7 @@ const DiffusionOptions = () => { updateExtenderByBuiltIn, updateExtenderDirection, ] = useStore((state) => [ + state.serverConfig.samplers, state.settings, state.paintByExampleFile, state.getIsProcessing(), @@ -652,16 +654,16 @@ const DiffusionOptions = () => {