get samplers from backend
parent a2fd5bb3ea
commit f38be37f8c
@@ -37,6 +37,7 @@ from lama_cleaner.schema import (
     SwitchModelRequest,
     InpaintRequest,
     RunPluginRequest,
+    SDSampler,
 )
 from lama_cleaner.file_manager import FileManager
 
@@ -129,6 +130,7 @@ class Api:
         self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
         self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
         self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"])
+        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
         self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
         # fmt: on
 
@@ -156,6 +158,7 @@ class Api:
             controlnetMethod=self.model_manager.controlnet_method,
             disableModelSwitch=self.config.disable_model_switch,
             isDesktop=self.config.gui,
+            samplers=self.api_samplers(),
         )
 
     def api_input_image(self) -> FileResponse:
@@ -237,6 +240,9 @@ class Api:
             media_type=f"image/{ext}",
         )
 
+    def api_samplers(self) -> List[str]:
+        return [member.value for member in SDSampler.__members__.values()]
+
     def launch(self):
         self.app.include_router(self.router)
         uvicorn.run(

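The API changes above expose the sampler list in two places: the new GET /api/v1/samplers route and the samplers field of the server-config response. A minimal client-side sketch in Python, assuming the server is reachable locally (the host and port below are illustrative) and the requests package is installed:

    import requests

    # The new endpoint returns the SDSampler display names as a JSON array of strings.
    samplers = requests.get("http://127.0.0.1:8080/api/v1/samplers").json()
    print(samplers)  # e.g. ["DPM++ 2M", "DPM++ 2M Karras", ...]

The frontend relies on the copy shipped inside the server-config response, so it no longer hard-codes the sampler enum.
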
@@ -2,6 +2,7 @@ import json
 import os
 from typing import List
 
+from huggingface_hub.constants import HF_HUB_CACHE
 from loguru import logger
 from pathlib import Path
 
@@ -101,13 +102,11 @@ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
 
 
 def scan_models() -> List[ModelInfo]:
-    from diffusers.utils import DIFFUSERS_CACHE
-
     model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
     available_models = []
     available_models.extend(scan_inpaint_models(model_dir))
     available_models.extend(scan_single_file_diffusion_models(model_dir))
-    cache_dir = Path(DIFFUSERS_CACHE)
+    cache_dir = Path(HF_HUB_CACHE)
     # logger.info(f"Scanning diffusers models in {cache_dir}")
     diffusers_model_names = []
     for it in cache_dir.glob("**/*/model_index.json"):

@@ -35,9 +35,7 @@ class Kandinsky(DiffusionInpaintModel):
         mask: [H, W, 1] 255 means area to repaint
         return: BGR IMAGE
         """
-        scheduler_config = self.model.scheduler.config
-        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
-        self.model.scheduler = scheduler
+        self.set_scheduler(config)
 
         generator = torch.manual_seed(config.sd_seed)
         mask = mask.astype(np.float32) / 255

@@ -1,3 +1,4 @@
+import copy
 import gc
 import math
 import random
@@ -18,10 +19,13 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     UniPCMultistepScheduler,
     LCMScheduler,
+    DPMSolverSinglestepScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler,
+    HeunDiscreteScheduler,
 )
-from huggingface_hub.utils import RevisionNotFoundError
+from diffusers.configuration_utils import FrozenDict
 from loguru import logger
-from requests import HTTPError
 
 from lama_cleaner.schema import SDSampler
 from torch import conv2d, conv_transpose2d
@@ -930,22 +934,41 @@ def set_seed(seed: int):
 
 
 def get_scheduler(sd_sampler, scheduler_config):
-    if sd_sampler == SDSampler.ddim:
-        return DDIMScheduler.from_config(scheduler_config)
-    elif sd_sampler == SDSampler.pndm:
-        return PNDMScheduler.from_config(scheduler_config)
-    elif sd_sampler == SDSampler.k_lms:
-        return LMSDiscreteScheduler.from_config(scheduler_config)
-    elif sd_sampler == SDSampler.k_euler:
-        return EulerDiscreteScheduler.from_config(scheduler_config)
-    elif sd_sampler == SDSampler.k_euler_a:
-        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
-    elif sd_sampler == SDSampler.dpm_plus_plus:
-        return DPMSolverMultistepScheduler.from_config(scheduler_config)
-    elif sd_sampler == SDSampler.uni_pc:
-        return UniPCMultistepScheduler.from_config(scheduler_config)
-    elif sd_sampler == SDSampler.lcm:
-        return LCMScheduler.from_config(scheduler_config)
+    # https://github.com/huggingface/diffusers/issues/4167
+    keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
+    scheduler_config = dict(scheduler_config)
+    for it in keys_to_pop:
+        scheduler_config.pop(it, None)
+
+    # fmt: off
+    samplers = {
+        SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
+        SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
+        SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
+        SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
+        SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
+        SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
+        SDSampler.dpm2: [KDPM2DiscreteScheduler],
+        SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
+        SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
+        SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
+        SDSampler.euler: [EulerDiscreteScheduler],
+        SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
+        SDSampler.heun: [HeunDiscreteScheduler],
+        SDSampler.lms: [LMSDiscreteScheduler],
+        SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
+        SDSampler.ddim: [DDIMScheduler],
+        SDSampler.pndm: [PNDMScheduler],
+        SDSampler.uni_pc: [UniPCMultistepScheduler],
+        SDSampler.lcm: [LCMScheduler],
+    }
+    # fmt: on
+    if sd_sampler in samplers:
+        if len(samplers[sd_sampler]) == 2:
+            scheduler_cls, kwargs = samplers[sd_sampler]
+        else:
+            scheduler_cls, kwargs = samplers[sd_sampler][0], {}
+        return scheduler_cls.from_config(scheduler_config, **kwargs)
     else:
         raise ValueError(sd_sampler)

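The rewritten get_scheduler above swaps the if/elif chain for a lookup table: each SDSampler maps to a scheduler class plus optional from_config keyword arguments (Karras sigmas, SDE algorithm type), and keys that a previously active scheduler may have left behind in the config are popped first. A sketch of what the lookup amounts to for one entry, assuming a loaded diffusers pipeline bound to the name pipe (the variable name is illustrative):

    from diffusers import DPMSolverMultistepScheduler

    # Roughly what get_scheduler(SDSampler.dpm_plus_plus_2m_karras, pipe.scheduler.config) resolves to:
    scheduler_config = dict(pipe.scheduler.config)
    for key in ["use_karras_sigmas", "algorithm_type"]:
        scheduler_config.pop(key, None)  # drop options set by a previously selected sampler
    scheduler = DPMSolverMultistepScheduler.from_config(
        scheduler_config, use_karras_sigmas=True
    )
    pipe.scheduler = scheduler  # callers such as set_scheduler() assign the result
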
@@ -40,15 +40,26 @@ class LDMSampler(str, Enum):
 
 
 class SDSampler(str, Enum):
-    ddim = "ddim"
-    pndm = "pndm"
-    k_lms = "k_lms"
-    k_euler = "k_euler"
-    k_euler_a = "k_euler_a"
-    dpm_plus_plus = "dpm++"
-    uni_pc = "uni_pc"
+    dpm_plus_plus_2m = "DPM++ 2M"
+    dpm_plus_plus_2m_karras = "DPM++ 2M Karras"
+    dpm_plus_plus_2m_sde = "DPM++ 2M SDE"
+    dpm_plus_plus_2m_sde_karras = "DPM++ 2M SDE Karras"
+    dpm_plus_plus_sde = "DPM++ SDE"
+    dpm_plus_plus_sde_karras = "DPM++ SDE Karras"
+    dpm2 = "DPM2"
+    dpm2_karras = "DPM2 Karras"
+    dpm2_a = "DPM2 a"
+    dpm2_a_karras = "DPM2 a Karras"
+    euler = "Euler"
+    euler_a = "Euler a"
+    heun = "Heun"
+    lms = "LMS"
+    lms_karras = "LMS Karras"
 
-    lcm = "lcm"
+    ddim = "DDIM"
+    pndm = "PNDM"
+    uni_pc = "UniPC"
+    lcm = "LCM"
 
 
 class FREEUConfig(BaseModel):
@@ -143,7 +154,7 @@ class InpaintRequest(BaseModel):
         le=1.0,
     )
     sd_mask_blur: int = Field(
-        33,
+        11,
         description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
     )
     sd_strength: float = Field(
@@ -268,6 +279,7 @@ class ServerConfigResponse(BaseModel):
     controlnetMethod: Optional[str]
     disableModelSwitch: bool
     isDesktop: bool
+    samplers: List[str]
 
 
 class SwitchModelRequest(BaseModel):

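Because the SDSampler values are now the human-readable names themselves, the strings served by /api/v1/samplers round-trip cleanly: a name sent back by the frontend in sd_sampler can be parsed straight into the enum. A small illustration using the values defined above:

    from lama_cleaner.schema import SDSampler

    assert SDSampler("DPM++ 2M Karras") is SDSampler.dpm_plus_plus_2m_karras
    assert SDSampler.dpm2_a.value == "DPM2 a"
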
@@ -49,7 +49,7 @@ def test_outpainting(name, device, rect):
         extender_width=rect[2],
         extender_height=rect[3],
         sd_guidance_scale=8.0,
-        sd_sampler=SDSampler.dpm_plus_plus,
+        sd_sampler=SDSampler.dpm_plus_plus_2m,
     )
 
     assert_equal(
@@ -92,7 +92,7 @@ def test_kandinsky_outpainting(name, device, rect):
         extender_width=rect[2],
         extender_height=rect[3],
         sd_guidance_scale=7,
-        sd_sampler=SDSampler.dpm_plus_plus,
+        sd_sampler=SDSampler.dpm_plus_plus_2m,
     )
 
     assert_equal(
@@ -136,7 +136,7 @@ def test_powerpaint_outpainting(name, device, rect):
         extender_width=rect[2],
         extender_height=rect[3],
         sd_guidance_scale=8.0,
-        sd_sampler=SDSampler.dpm_plus_plus,
+        sd_sampler=SDSampler.dpm_plus_plus_2m,
         powerpaint_task="outpainting",
     )
 
@@ -1,5 +1,7 @@
 import os
 
+from loguru import logger
+
 from lama_cleaner.tests.utils import check_device, get_config, assert_equal
 
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -17,21 +19,7 @@ save_dir.mkdir(exist_ok=True, parents=True)
 
 
 @pytest.mark.parametrize("device", ["cuda", "mps"])
-@pytest.mark.parametrize(
-    "sampler",
-    [
-        SDSampler.ddim,
-        SDSampler.pndm,
-        SDSampler.k_lms,
-        SDSampler.k_euler,
-        SDSampler.k_euler_a,
-        SDSampler.lcm,
-    ],
-)
-def test_runway_sd_1_5_all_samplers(
-    device,
-    sampler,
-):
+def test_runway_sd_1_5_all_samplers(device):
     sd_steps = check_device(device)
     model = ModelManager(
         name="runwayml/stable-diffusion-inpainting",
@@ -39,22 +27,37 @@ def test_runway_sd_1_5_all_samplers(
         disable_nsfw=True,
         sd_cpu_textencoder=False,
     )
-    cfg = get_config(
-        strategy=HDStrategy.ORIGINAL,
-        prompt="a fox sitting on a bench",
-        sd_steps=sd_steps,
-    )
-    cfg.sd_sampler = sampler
 
-    name = f"device_{device}_{sampler}"
-
-    assert_equal(
-        model,
-        cfg,
-        f"runway_sd_{name}.png",
-        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
-        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
-    )
+    all_samplers = [member.value for member in SDSampler.__members__.values()]
+    print(all_samplers)
+    for sampler in all_samplers:
+        print(f"Testing sampler {sampler}")
+        if (
+            sampler
+            in [SDSampler.dpm2_karras, SDSampler.dpm2_a_karras, SDSampler.lms_karras]
+            and device == "mps"
+        ):
+            # diffusers 0.25.0 still has bug on these sampler on mps, wait main branch released to fix it
+            logger.warning(
+                "skip dpm2_karras on mps, diffusers does not support it on mps. TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead."
+            )
+            continue
+        cfg = get_config(
+            strategy=HDStrategy.ORIGINAL,
+            prompt="a fox sitting on a bench",
+            sd_steps=sd_steps,
+            sd_sampler=sampler,
+        )
+
+        name = f"device_{device}_{sampler}"
+
+        assert_equal(
+            model,
+            cfg,
+            f"runway_sd_{name}.png",
+            img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+            mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+        )
 
 
 @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@@ -171,7 +174,7 @@ def test_runway_norm_sd_model(device, strategy, sampler):
 
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
-@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
+@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
 def test_runway_sd_1_5_cpu_offload(device, strategy, sampler):
     sd_steps = check_device(device)
     model = ModelManager(

@@ -3,7 +3,9 @@ import cv2
 import pytest
 import torch
 
+from lama_cleaner.helper import encode_pil_to_base64
 from lama_cleaner.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
+from PIL import Image
 
 current_dir = Path(__file__).parent.absolute().resolve()
 save_dir = current_dir / "result"
@@ -21,7 +23,7 @@ def check_device(device: str) -> int:
 
 def assert_equal(
     model,
-    config,
+    config: InpaintRequest,
     gt_name,
     fx: float = 1,
     fy: float = 1,
@@ -29,6 +31,8 @@ def assert_equal(
     mask_p=current_dir / "mask.png",
 ):
     img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
+    config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
+    config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
     print(f"Input image shape: {img.shape}")
     res = model(img, mask, config)
     ok = cv2.imwrite(
@@ -72,4 +76,4 @@ def get_config(**kwargs):
         hd_strategy_resize_limit=200,
     )
     data.update(**kwargs)
-    return InpaintRequest(**data)
+    return InpaintRequest(image="", mask="", **data)

@@ -1,20 +1,18 @@
 torch>=2.0.0
-typer
 opencv-python
+diffusers==0.25.0
+transformers==4.34.1
+safetensors
+controlnet-aux==0.0.3
 fastapi==0.108.0
 python-multipart
 simple-websocket
-flask_cors
 flaskwebgui==0.3.5
+typer
 pydantic
 rich
 loguru
 yacs
-diffusers==0.23.0
-transformers==4.34.1
 gradio
 piexif==1.1.3
-safetensors
 omegaconf
-controlnet-aux==0.0.3
-werkzeug==2.2.2

@@ -5,7 +5,7 @@ from typing import Dict, List, Union
 import torch
 
 from diffusers.utils import is_safetensors_available
-
+from huggingface_hub.constants import HF_HUB_CACHE
 
 if is_safetensors_available():
     import safetensors.torch
@@ -16,7 +16,6 @@ from diffusers import DiffusionPipeline, __version__
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import (
     CONFIG_NAME,
-    DIFFUSERS_CACHE,
     ONNX_WEIGHTS_NAME,
     WEIGHTS_NAME,
 )
@@ -96,7 +95,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
 
         """
         # Default kwargs from DiffusionPipeline
-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -246,7 +245,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
                 print(f"Skipping {attr}: not present in 2nd or 3d model")
                 continue
-
 
             try:
                 module = getattr(final_pipe, attr)
                 if isinstance(
@@ -267,7 +265,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
                    else torch.load(checkpoint_path_1, map_location="cpu")
                )
 
-                if attr in ['vae', 'text_encoder']:
+                if attr in ["vae", "text_encoder"]:
                     print(f"Direct use theta1 {attr}: {checkpoint_path_1}")
                     update_theta_0(theta_1)
                     del theta_1
@@ -348,7 +346,7 @@ pipe = CheckpointMergerPipeline.from_pretrained("runwayml/stable-diffusion-inpai
 merged_pipe = pipe.merge(
     [
         "runwayml/stable-diffusion-inpainting",
-        #"SG161222/Realistic_Vision_V1.4",
+        # "SG161222/Realistic_Vision_V1.4",
         "dreamlike-art/dreamlike-diffusion-1.0",
         "runwayml/stable-diffusion-v1-5",
     ],
@@ -358,4 +356,6 @@ merged_pipe = pipe.merge(
 )
 
 merged_pipe = merged_pipe.to(torch.float16)
-merged_pipe.save_pretrained("dreamlike-diffusion-1.0-inpainting", safe_serialization=True)
+merged_pipe.save_pretrained(
+    "dreamlike-diffusion-1.0-inpainting", safe_serialization=True
+)

@@ -11,7 +11,7 @@ import {
   SelectValue,
 } from "../ui/select"
 import { Textarea } from "../ui/textarea"
-import { ExtenderDirection, PowerPaintTask, SDSampler } from "@/lib/types"
+import { ExtenderDirection, PowerPaintTask } from "@/lib/types"
 import { Separator } from "../ui/separator"
 import { Button, ImageUploadButton } from "../ui/button"
 import { Slider } from "../ui/slider"
@@ -42,6 +42,7 @@ const ExtenderButton = ({
 
 const DiffusionOptions = () => {
   const [
+    samplers,
     settings,
     paintByExampleFile,
     isProcessing,
@@ -51,6 +52,7 @@ const DiffusionOptions = () => {
     updateExtenderByBuiltIn,
     updateExtenderDirection,
   ] = useStore((state) => [
+    state.serverConfig.samplers,
     state.settings,
     state.paintByExampleFile,
     state.getIsProcessing(),
@@ -652,16 +654,16 @@ const DiffusionOptions = () => {
         <Select
           defaultValue={settings.sdSampler}
           value={settings.sdSampler}
-          onValueChange={(value: SDSampler) => {
+          onValueChange={(value) => {
             updateSettings({ sdSampler: value })
           }}
         >
-          <SelectTrigger className="w-[120px]">
+          <SelectTrigger className="w-[180px]">
            <SelectValue placeholder="Select sampler" />
          </SelectTrigger>
          <SelectContent align="end">
            <SelectGroup>
-              {Object.values(SDSampler).map((sampler) => (
+              {samplers.map((sampler) => (
                <SelectItem key={sampler} value={sampler}>
                  {sampler}
                </SelectItem>

@@ -194,3 +194,8 @@ export async function getGenInfo(file: File): Promise<GenInfo> {
   const res = await api.post(`/gen-info`, fd)
   return res.data
 }
+
+export async function getSamplers(): Promise<string[]> {
+  const res = await api.post("/samplers")
+  return res.data
+}

@@ -14,7 +14,6 @@ import {
   PluginParams,
   Point,
   PowerPaintTask,
-  SDSampler,
   ServerConfig,
   Size,
   SortBy,
@@ -85,7 +84,7 @@ export type Settings = {
   sdStrength: number
   sdSteps: number
   sdGuidanceScale: number
-  sdSampler: SDSampler
+  sdSampler: string
   sdMatchHistograms: boolean
   sdScale: number
 
@@ -274,6 +273,7 @@ const defaultValues: AppState = {
     controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
     disableModelSwitch: false,
     isDesktop: false,
+    samplers: ["DPM++ 2M"],
   },
   settings: {
     model: {
@@ -310,7 +310,7 @@ const defaultValues: AppState = {
     sdStrength: 1.0,
     sdSteps: 50,
     sdGuidanceScale: 7.5,
-    sdSampler: SDSampler.uni_pc,
+    sdSampler: "DPM++ 2M",
     sdMatchHistograms: false,
     sdScale: 100,
     p2pImageGuidanceScale: 1.5,

@@ -14,6 +14,7 @@ export interface ServerConfig {
   controlnetMethod: string
   disableModelSwitch: boolean
   isDesktop: boolean
+  samplers: string[]
 }
 
 export interface GenInfo {
@@ -82,17 +83,6 @@ export interface Rect {
   height: number
 }
 
-export enum SDSampler {
-  ddim = "ddim",
-  pndm = "pndm",
-  klms = "k_lms",
-  kEuler = "k_euler",
-  kEulerA = "k_euler_a",
-  dpmPlusPlus = "dpm++",
-  uni_pc = "uni_pc",
-  // lcm = "lcm",
-}
-
 export interface FreeuConfig {
   s1: number
   s2: number