diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index a1cd881..d9745d4 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -91,7 +91,8 @@ class SD(InpaintModel): if kwargs['sd_cpu_textencoder']: logger.info("Run Stable Diffusion TextEncoder on CPU") - self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu')) + self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True) + self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True ) self.callbacks = kwargs.pop("callbacks", None) diff --git a/lama_cleaner/model/sd_pipeline.py b/lama_cleaner/model/sd_pipeline.py index 6616406..827a383 100644 --- a/lama_cleaner/model/sd_pipeline.py +++ b/lama_cleaner/model/sd_pipeline.py @@ -236,7 +236,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + text_encoder_device = self.text_encoder.device + + text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -248,7 +250,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -269,7 +271,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py index cd70ad7..f2eac96 100644 --- a/lama_cleaner/tests/test_model.py +++ b/lama_cleaner/tests/test_model.py @@ -9,6 +9,8 @@ from lama_cleaner.model_manager import ModelManager from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / 'result' +save_dir.mkdir(exist_ok=True, parents=True) device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -40,7 +42,7 @@ def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image. img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) res = model(img, mask, config) cv2.imwrite( - str(current_dir / gt_name), + str(save_dir / gt_name), res, [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], ) @@ -163,7 +165,12 @@ def test_sd(strategy, sampler): print(f"sd_step_{step}") sd_steps = 50 - model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], + model = ModelManager(name="sd1.4", + device=device, + hf_access_token=os.environ['HF_ACCESS_TOKEN'], + sd_run_local=False, + sd_disable_nsfw=False, + sd_cpu_textencoder=False, callbacks=[callback]) cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps) cfg.sd_sampler = sampler @@ -187,7 +194,8 @@ def test_sd(strategy, sampler): @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("sampler", [SDSampler.ddim]) -def test_sd_run_local(strategy, sampler): +@pytest.mark.parametrize("disable_nsfw", [True, False]) +def test_sd_run_local(strategy, sampler, disable_nsfw): def callback(step: int): print(f"sd_step_{step}") @@ -195,11 +203,11 @@ def test_sd_run_local(strategy, sampler): model = ModelManager( name="sd1.4", device=device, + # hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None), hf_access_token=None, sd_run_local=True, - sd_disable_nsfw=True, + sd_disable_nsfw=disable_nsfw, sd_cpu_textencoder=True, - callbacks=[callback] ) cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps) cfg.sd_sampler = sampler @@ -219,3 +227,4 @@ def test_sd_run_local(strategy, sampler): img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png", ) + diff --git a/requirements.txt b/requirements.txt index b029bb5..4715cee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch>=1.8.2 +torch>=1.9.0 opencv-python flask_cors flask==1.1.4