diff --git a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
index c0e1b7c..905979b 100644
--- a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
+++ b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx
@@ -230,6 +230,18 @@ function ModelSettingBlock() {
         'https://ommer-lab.com/research/latent-diffusion-models/',
         'https://github.com/CompVis/stable-diffusion'
       )
+    case AIModel.ANYTHING4:
+      return renderModelDesc(
+        'andite/anything-v4.0',
+        'https://huggingface.co/andite/anything-v4.0',
+        'https://huggingface.co/andite/anything-v4.0'
+      )
+    case AIModel.REALISTIC_VISION_1_4:
+      return renderModelDesc(
+        'SG161222/Realistic_Vision_V1.4',
+        'https://huggingface.co/SG161222/Realistic_Vision_V1.4',
+        'https://huggingface.co/SG161222/Realistic_Vision_V1.4'
+      )
     case AIModel.SD2:
       return renderModelDesc(
         'Stable Diffusion 2',
diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx
index 277ffde..dbfb7da 100644
--- a/lama_cleaner/app/src/store/Atoms.tsx
+++ b/lama_cleaner/app/src/store/Atoms.tsx
@@ -10,6 +10,8 @@ export enum AIModel {
   MAT = 'mat',
   FCF = 'fcf',
   SD15 = 'sd1.5',
+  ANYTHING4 = 'anything4',
+  REALISTIC_VISION_1_4 = 'realisticVision1.4',
   SD2 = 'sd2',
   CV2 = 'cv2',
   Mange = 'manga',
@@ -422,6 +424,20 @@ const defaultHDSettings: ModelsHDSettings = {
     hdStrategyCropMargin: 128,
     enabled: false,
   },
+  [AIModel.ANYTHING4]: {
+    hdStrategy: HDStrategy.ORIGINAL,
+    hdStrategyResizeLimit: 768,
+    hdStrategyCropTrigerSize: 512,
+    hdStrategyCropMargin: 128,
+    enabled: false,
+  },
+  [AIModel.REALISTIC_VISION_1_4]: {
+    hdStrategy: HDStrategy.ORIGINAL,
+    hdStrategyResizeLimit: 768,
+    hdStrategyCropTrigerSize: 512,
+    hdStrategyCropMargin: 128,
+    enabled: false,
+  },
   [AIModel.SD2]: {
     hdStrategy: HDStrategy.ORIGINAL,
     hdStrategyResizeLimit: 768,
@@ -601,7 +617,12 @@ export const isSDState = selector({
   key: 'isSD',
   get: ({ get }) => {
     const settings = get(settingState)
-    return settings.model === AIModel.SD15 || settings.model === AIModel.SD2
+    return (
+      settings.model === AIModel.SD15 ||
+      settings.model === AIModel.SD2 ||
+      settings.model === AIModel.ANYTHING4 ||
+      settings.model === AIModel.REALISTIC_VISION_1_4
+    )
   },
 })
 
diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py
index 401937a..af9375c 100644
--- a/lama_cleaner/const.py
+++ b/lama_cleaner/const.py
@@ -3,6 +3,8 @@ import os
 MPS_SUPPORT_MODELS = [
     "instruct_pix2pix",
     "sd1.5",
+    "anything4",
+    "realisticVision1.4",
     "sd2",
     "paint_by_example"
 ]
@@ -15,6 +17,8 @@ AVAILABLE_MODELS = [
     "mat",
     "fcf",
     "sd1.5",
+    "anything4",
+    "realisticVision1.4",
     "cv2",
     "manga",
     "sd2",
diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py
index 71618ce..019de9b 100644
--- a/lama_cleaner/model/sd.py
+++ b/lama_cleaner/model/sd.py
@@ -136,7 +136,7 @@ class SD(DiffusionInpaintModel):
             callback=self.callback,
             height=img_h,
             width=img_w,
-            generator=torch.manual_seed(config.sd_seed)
+            generator=torch.manual_seed(config.sd_seed),
         ).images[0]
 
         output = (output * 255).round().astype("uint8")
@@ -163,6 +163,16 @@ class SD15(SD):
     model_id_or_path = "runwayml/stable-diffusion-inpainting"
 
 
+class Anything4(SD):
+    name = "anything4"
+    model_id_or_path = "Sanster/anything-4.0-inpainting"
+
+
+class RealisticVision14(SD):
+    name = "realisticVision1.4"
+    model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting"
+
+
 class SD2(SD):
     name = "sd2"
     model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"
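The new backends in sd.py are deliberately thin: each subclass overrides only name and model_id_or_path, and inherits sampling and HD-strategy handling from SD. A minimal sketch of how they become reachable once registered in the models dict below — hedged, since it assumes lama-cleaner's existing ModelManager(name, device) constructor and its __call__(image, mask, config) interface, which this diff does not change:

    import torch
    from lama_cleaner.model_manager import ModelManager
    from lama_cleaner.schema import Config

    # "anything4" resolves to the Anything4 class through the models dict
    # registered in model_manager.py below (assumed interfaces, unchanged here).
    model = ModelManager(name="anything4", device=torch.device("cpu"))
    # result = model(image, mask, Config(...))  # inpaint exactly as with sd1.5
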
diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py
index 7b5720a..bdf0d78 100644
--- a/lama_cleaner/model_manager.py
+++ b/lama_cleaner/model_manager.py
@@ -9,7 +9,7 @@ from lama_cleaner.model.manga import Manga
 from lama_cleaner.model.mat import MAT
 from lama_cleaner.model.paint_by_example import PaintByExample
 from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
-from lama_cleaner.model.sd import SD15, SD2
+from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
 from lama_cleaner.model.zits import ZITS
 from lama_cleaner.model.opencv2 import OpenCV2
 from lama_cleaner.schema import Config
@@ -21,6 +21,8 @@ models = {
     "mat": MAT,
     "fcf": FcF,
     "sd1.5": SD15,
+    Anything4.name: Anything4,
+    RealisticVision14.name: RealisticVision14,
     "cv2": OpenCV2,
     "manga": Manga,
     "sd2": SD2,
diff --git a/scripts/convert_vae_pt_to_diffusers.py b/scripts/convert_vae_pt_to_diffusers.py
new file mode 100644
index 0000000..2d79d12
--- /dev/null
+++ b/scripts/convert_vae_pt_to_diffusers.py
@@ -0,0 +1,238 @@
+import argparse
+import io
+
+import requests
+import torch
+from omegaconf import OmegaConf
+
+from diffusers import AutoencoderKL
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    assign_to_checkpoint,
+    conv_attn_to_linear,
+    create_vae_diffusers_config,
+    renew_vae_attention_paths,
+    renew_vae_resnet_paths,
+)
+
+
+def custom_convert_ldm_vae_checkpoint(checkpoint, config):
+    vae_state_dict = checkpoint
+
+    new_checkpoint = {}
+
+    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
+        "encoder.conv_out.weight"
+    ]
+    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
+        "encoder.norm_out.weight"
+    ]
+    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
+        "encoder.norm_out.bias"
+    ]
+
+    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
+        "decoder.conv_out.weight"
+    ]
+    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
+        "decoder.norm_out.weight"
+    ]
+    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
+        "decoder.norm_out.bias"
+    ]
+
+    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+    # Retrieves the keys for the encoder down blocks only
+    num_down_blocks = len(
+        {
+            ".".join(layer.split(".")[:3])
+            for layer in vae_state_dict
+            if "encoder.down" in layer
+        }
+    )
+    down_blocks = {
+        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
+        for layer_id in range(num_down_blocks)
+    }
+
+    # Retrieves the keys for the decoder up blocks only
+    num_up_blocks = len(
+        {
+            ".".join(layer.split(".")[:3])
+            for layer in vae_state_dict
+            if "decoder.up" in layer
+        }
+    )
+    up_blocks = {
+        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
+        for layer_id in range(num_up_blocks)
+    }
+
+    for i in range(num_down_blocks):
+        resnets = [
+            key
+            for key in down_blocks[i]
+            if f"down.{i}" in key and f"down.{i}.downsample" not in key
+        ]
+
+        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+            new_checkpoint[
+                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
+            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
+            new_checkpoint[
+                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
+            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )
+
+    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )
+
+    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        vae_state_dict,
+        additional_replacements=[meta_path],
+        config=config,
+    )
+    conv_attn_to_linear(new_checkpoint)
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+        resnets = [
+            key
+            for key in up_blocks[block_id]
+            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+        ]
+
+        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+            new_checkpoint[
+                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
+            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
+            new_checkpoint[
+                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
+            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )
+
+    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )
+
+    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        vae_state_dict,
+        additional_replacements=[meta_path],
+        config=config,
+    )
+    conv_attn_to_linear(new_checkpoint)
+    return new_checkpoint
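+
+# Note on the key mapping above: the LDM VAE stores layers as
+# "encoder.down.{i}.block" / "decoder.up.{i}.block", while diffusers'
+# AutoencoderKL expects "encoder.down_blocks.{i}.resnets" /
+# "decoder.up_blocks.{i}.resnets", with decoder up blocks re-indexed in
+# reverse (block_id = num_up_blocks - 1 - i); renew_vae_*_paths builds the
+# rename map and assign_to_checkpoint applies it.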
+
+
+def vae_pt_to_vae_diffuser(
+    checkpoint_path: str,
+    output_path: str,
+):
+    # Only supports v1 configs
+    r = requests.get(
+        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+    )
+    io_obj = io.BytesIO(r.content)
+
+    original_config = OmegaConf.load(io_obj)
+    image_size = 512
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    # Convert the VAE model.
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(
+        checkpoint["state_dict"], vae_config
+    )
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    vae.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--vae_pt_path",
+        default="/Users/cwq/code/github/lama-cleaner/scripts/anything-v4.0.vae.pt",
+        type=str,
+        help="Path to the VAE.pt to convert.",
+    )
+    parser.add_argument(
+        "--dump_path",
+        default="diffusion_pytorch_model.bin",
+        type=str,
+        help="Path to save the converted diffusers VAE to.",
+    )
+
+    args = parser.parse_args()
+
+    vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
diff --git a/scripts/tool.py b/scripts/tool.py
new file mode 100644
index 0000000..64a2c07
--- /dev/null
+++ b/scripts/tool.py
@@ -0,0 +1,361 @@
+import glob
+import os
+from typing import Dict, List, Union
+
+import torch
+
+from diffusers.utils import is_safetensors_available
+
+
+if is_safetensors_available():
+    import safetensors.torch
+
+from huggingface_hub import snapshot_download
+
+from diffusers import DiffusionPipeline, __version__
+from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from diffusers.utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    ONNX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+)
+
+
+class CheckpointMergerPipeline(DiffusionPipeline):
+    """
+    A class that supports merging diffusion models based on the discussion here:
+    https://github.com/huggingface/diffusers/issues/877
+
+    Example usage:
+
+    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger.py")
+
+    merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "prompthero/openjourney"], interp='inv_sigmoid', alpha=0.8, force=True)
+
+    merged_pipe.to('cuda')
+
+    prompt = "An astronaut riding a unicycle on Mars"
+
+    results = merged_pipe(prompt)
+
+    ## For more details, see the docstring for the merge method.
+
+    """
+
+    def __init__(self):
+        self.register_to_config()
+        super().__init__()
+
+    def _compare_model_configs(self, dict0, dict1):
+        if dict0 == dict1:
+            return True
+        else:
+            config0, meta_keys0 = self._remove_meta_keys(dict0)
+            config1, meta_keys1 = self._remove_meta_keys(dict1)
+            if config0 == config1:
+                print(f"Warning!: Mismatch in keys {meta_keys0} and {meta_keys1}.")
+                return True
+        return False
+
+    def _remove_meta_keys(self, config_dict: Dict):
+        meta_keys = []
+        temp_dict = config_dict.copy()
+        for key in config_dict.keys():
+            if key.startswith("_"):
+                temp_dict.pop(key)
+                meta_keys.append(key)
+        return (temp_dict, meta_keys)
+
+    @torch.no_grad()
+    def merge(
+        self,
+        pretrained_model_name_or_path_list: List[Union[str, os.PathLike]],
+        **kwargs,
+    ):
+        """
+        Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints (weights) of the models passed
+        in the argument 'pretrained_model_name_or_path_list' as a list.
+
+        Parameters:
+        -----------
+        pretrained_model_name_or_path_list : A list of valid pretrained model names in the HuggingFace hub or paths to locally stored models in the HuggingFace format.
+
+        **kwargs:
+            Supports all the default DiffusionPipeline.get_config_dict kwargs, viz.
+
+            cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.
+
+        alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. An alpha of 0.8
+            means the first model's checkpoints affect the final result far less than they would with an alpha of 0.2.
+
+        interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
+            Passing None uses the default interpolation, which is weighted-sum interpolation. For merging three checkpoints, only "add_diff" is supported.
+
+        force - Whether to ignore mismatches in model_index.json for the given models. Defaults to False.
+
+        """
+        # Default kwargs from DiffusionPipeline
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        device_map = kwargs.pop("device_map", None)
+
+        alpha = kwargs.pop("alpha", 0.5)
+        interp = kwargs.pop("interp", None)
+
+        print("Received list", pretrained_model_name_or_path_list)
+        print(f"Combining with alpha={alpha}, interpolation mode={interp}")
+
+        checkpoint_count = len(pretrained_model_name_or_path_list)
+        # Ignore the result of the model_index.json comparison of the two checkpoints
+        force = kwargs.pop("force", False)
+
+        # If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.
+        if checkpoint_count > 3 or checkpoint_count < 2:
+            raise ValueError(
+                "Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
+                " passed."
+            )
+
+        print("Received the right number of checkpoints")
+        # chkpt0, chkpt1 = pretrained_model_name_or_path_list[0:2]
+        # chkpt2 = pretrained_model_name_or_path_list[2] if checkpoint_count == 3 else None
+
+        # Validate that the checkpoints can be merged
+        # Step 1: Load the model config and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
+        config_dicts = []
+        for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
+            config_dict = DiffusionPipeline.load_config(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+            config_dicts.append(config_dict)
+
+        comparison_result = True
+        for idx in range(1, len(config_dicts)):
+            comparison_result &= self._compare_model_configs(
+                config_dicts[idx - 1], config_dicts[idx]
+            )
+            if not force and comparison_result is False:
+                raise ValueError(
+                    "Incompatible checkpoints. Please check model_index.json for the models."
+                )
+        print(config_dicts[0], config_dicts[1])
+        print("Compatible model_index.json files found")
+        # Step 2: Basic validation has succeeded. Let's download the models and save them into our local files.
+        cached_folders = []
+        for pretrained_model_name_or_path, config_dict in zip(
+            pretrained_model_name_or_path_list, config_dicts
+        ):
+            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+            allow_patterns = [os.path.join(k, "*") for k in folder_names]
+            allow_patterns += [
+                WEIGHTS_NAME,
+                SCHEDULER_CONFIG_NAME,
+                CONFIG_NAME,
+                ONNX_WEIGHTS_NAME,
+                DiffusionPipeline.config_name,
+            ]
+            requested_pipeline_class = config_dict.get("_class_name")
+            user_agent = {
+                "diffusers": __version__,
+                "pipeline_class": requested_pipeline_class,
+            }
+
+            cached_folder = (
+                pretrained_model_name_or_path
+                if os.path.isdir(pretrained_model_name_or_path)
+                else snapshot_download(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    allow_patterns=allow_patterns,
+                    user_agent=user_agent,
+                )
+            )
+            print("Cached folder", cached_folder)
+            cached_folders.append(cached_folder)
+
+        # Step 3:
+        # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
+        final_pipe = DiffusionPipeline.from_pretrained(
+            cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
+        )
+        final_pipe.to(self.device)
+
+        checkpoint_path_2 = None
+        if len(cached_folders) > 2:
+            checkpoint_path_2 = os.path.join(cached_folders[2])
+
+        if interp == "sigmoid":
+            theta_func = CheckpointMergerPipeline.sigmoid
+        elif interp == "inv_sigmoid":
+            theta_func = CheckpointMergerPipeline.inv_sigmoid
+        elif interp == "add_diff":
+            theta_func = CheckpointMergerPipeline.add_difference
+        else:
+            theta_func = CheckpointMergerPipeline.weighted_sum
+
+        # Find each module's state dict.
+        for attr in final_pipe.config.keys():
+            if not attr.startswith("_"):
+                checkpoint_path_1 = os.path.join(cached_folders[1], attr)
+                if os.path.exists(checkpoint_path_1):
+                    files = list(
+                        (
+                            *glob.glob(
+                                os.path.join(checkpoint_path_1, "*.safetensors")
+                            ),
+                            *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
+                        )
+                    )
+                    checkpoint_path_1 = files[0] if len(files) > 0 else None
+                if len(cached_folders) < 3:
+                    checkpoint_path_2 = None
+                else:
+                    checkpoint_path_2 = os.path.join(cached_folders[2], attr)
+                    if os.path.exists(checkpoint_path_2):
+                        files = list(
+                            (
+                                *glob.glob(
+                                    os.path.join(checkpoint_path_2, "*.safetensors")
+                                ),
+                                *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
+                            )
+                        )
+                        checkpoint_path_2 = files[0] if len(files) > 0 else None
+                # For a given attr, if both checkpoint_path_1 and 2 are None, ignore it.
+                # If at least one is present, deal with it according to the interp method, of course only if the state_dict keys match.
+                if checkpoint_path_1 is None and checkpoint_path_2 is None:
+                    print(f"Skipping {attr}: not present in the 2nd or 3rd model")
+                    continue
+
+                try:
+                    module = getattr(final_pipe, attr)
+                    if isinstance(
+                        module, bool
+                    ):  # ignore requires_safety_checker boolean
+                        continue
+                    theta_0 = getattr(module, "state_dict")
+                    theta_0 = theta_0()
+
+                    update_theta_0 = getattr(module, "load_state_dict")
+
+                    theta_1 = (
+                        safetensors.torch.load_file(checkpoint_path_1)
+                        if (
+                            is_safetensors_available()
+                            and checkpoint_path_1.endswith(".safetensors")
+                        )
+                        else torch.load(checkpoint_path_1, map_location="cpu")
+                    )
+
+                    if attr in ["vae", "text_encoder"]:
+                        print(f"Directly using theta_1 for {attr}: {checkpoint_path_1}")
+                        update_theta_0(theta_1)
+                        del theta_1
+                        del theta_0
+                        continue
+
+                    theta_2 = None
+                    if checkpoint_path_2:
+                        theta_2 = (
+                            safetensors.torch.load_file(checkpoint_path_2)
+                            if (
+                                is_safetensors_available()
+                                and checkpoint_path_2.endswith(".safetensors")
+                            )
+                            else torch.load(checkpoint_path_2, map_location="cpu")
+                        )
+
+                    if theta_0.keys() != theta_1.keys():
+                        print(f"Skipping {attr}: key mismatch")
+                        continue
+                    if theta_2 and theta_1.keys() != theta_2.keys():
+                        print(f"Skipping {attr}: key mismatch")
+                        continue
+                except Exception as e:
+                    print(f"Skipping {attr} due to an unexpected error: {str(e)}")
+                    continue
+                print(f"MERGING {attr}")
+
+                for key in theta_0.keys():
+                    if theta_2:
+                        theta_0[key] = theta_func(
+                            theta_0[key], theta_1[key], theta_2[key], alpha
+                        )
+                    else:
+                        theta_0[key] = theta_func(
+                            theta_0[key], theta_1[key], None, alpha
+                        )
+
+                del theta_1
+                del theta_2
+                update_theta_0(theta_0)
+
+                del theta_0
+        return final_pipe
+
+    @staticmethod
+    def weighted_sum(theta0, theta1, theta2, alpha):
+        return ((1 - alpha) * theta0) + (alpha * theta1)
+
+    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+    @staticmethod
+    def sigmoid(theta0, theta1, theta2, alpha):
+        alpha = alpha * alpha * (3 - (2 * alpha))
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+    @staticmethod
+    def inv_sigmoid(theta0, theta1, theta2, alpha):
+        import math
+
+        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+        return theta0 + ((theta1 - theta0) * alpha)
+
+    @staticmethod
+    def add_difference(theta0, theta1, theta2, alpha):
+        # theta0 + (theta1 - theta2) * (1.0 - alpha)
+        diff = (theta1 - theta2) * (1.0 - alpha)
+        # The inpainting conv_in has extra input channels, so only the first
+        # 4 channels receive the diff when the shapes do not match.
+        if theta0.shape != diff.shape:
+            theta0[:, 0:4, :, :] = theta0[:, 0:4, :, :] + diff
+        else:
+            theta0 = theta0 + diff
+        return theta0
+
+
+pipe = CheckpointMergerPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+merged_pipe = pipe.merge(
+    [
+        "runwayml/stable-diffusion-inpainting",
+        # "SG161222/Realistic_Vision_V1.4",
+        "dreamlike-art/dreamlike-diffusion-1.0",
+        "runwayml/stable-diffusion-v1-5",
+    ],
+    force=True,
+    interp="add_diff",
+    alpha=0,
+)
+
+merged_pipe = merged_pipe.to(torch.float16)
+merged_pipe.save_pretrained("dreamlike-diffusion-1.0-inpainting", safe_serialization=True)
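Taken together, the two scripts produce the inpainting checkpoints this diff points at: tool.py grafts inpainting weights onto a base model via add_difference, and convert_vae_pt_to_diffusers.py turns a standalone .vae.pt into a diffusers-format VAE. A minimal sketch of wiring a converted VAE into such a pipeline — the directory names are hypothetical (whatever --dump_path and save_pretrained() wrote above); AutoencoderKL.from_pretrained and the vae= override are standard diffusers APIs:

    import torch
    from diffusers import AutoencoderKL, StableDiffusionInpaintPipeline

    # Hypothetical paths, not part of this diff.
    vae = AutoencoderKL.from_pretrained("anything-v4.0-vae")  # --dump_path output
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "dreamlike-diffusion-1.0-inpainting",  # save_pretrained() output of tool.py
        vae=vae,
        torch_dtype=torch.float16,
    )
    pipe = pipe.to("cuda")  # assumes a CUDA device

Once such a checkpoint is published to the Hub (e.g. Sanster/anything-4.0-inpainting above), the new backends are selectable through the existing CLI flag, e.g. "lama-cleaner --model anything4", since "anything4" is now listed in AVAILABLE_MODELS.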