Enable running the Stable Diffusion text_encoder on CPU
This commit is contained in:
parent
db1d7d5c48
commit
dba7b01da7
@ -91,7 +91,8 @@ class SD(InpaintModel):
|
|||||||
|
|
||||||
if kwargs['sd_cpu_textencoder']:
|
if kwargs['sd_cpu_textencoder']:
|
||||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'))
|
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
|
||||||
|
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
|
||||||
|
|
||||||
self.callbacks = kwargs.pop("callbacks", None)
|
self.callbacks = kwargs.pop("callbacks", None)
|
||||||
|
|
||||||
|
@ -236,7 +236,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||||||
truncation=True,
|
truncation=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
text_encoder_device = self.text_encoder.device
|
||||||
|
|
||||||
|
text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
|
||||||
|
|
||||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
@ -248,7 +250,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||||||
uncond_input = self.tokenizer(
|
uncond_input = self.tokenizer(
|
||||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
||||||
)
|
)
|
||||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True)
|
||||||
|
|
||||||
# For classifier free guidance, we need to do two forward passes.
|
# For classifier free guidance, we need to do two forward passes.
|
||||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||||
@ -269,7 +271,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||||||
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
|
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@ from lama_cleaner.model_manager import ModelManager
|
|||||||
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
|
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
|
||||||
|
|
||||||
current_dir = Path(__file__).parent.absolute().resolve()
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
save_dir = current_dir / 'result'
|
||||||
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +42,7 @@ def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.
|
|||||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||||
res = model(img, mask, config)
|
res = model(img, mask, config)
|
||||||
cv2.imwrite(
|
cv2.imwrite(
|
||||||
str(current_dir / gt_name),
|
str(save_dir / gt_name),
|
||||||
res,
|
res,
|
||||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
||||||
)
|
)
|
||||||
@ -163,7 +165,12 @@ def test_sd(strategy, sampler):
|
|||||||
print(f"sd_step_{step}")
|
print(f"sd_step_{step}")
|
||||||
|
|
||||||
sd_steps = 50
|
sd_steps = 50
|
||||||
model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
model = ModelManager(name="sd1.4",
|
||||||
|
device=device,
|
||||||
|
hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
||||||
|
sd_run_local=False,
|
||||||
|
sd_disable_nsfw=False,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
callbacks=[callback])
|
callbacks=[callback])
|
||||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
@ -187,7 +194,8 @@ def test_sd(strategy, sampler):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
def test_sd_run_local(strategy, sampler):
|
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||||
|
def test_sd_run_local(strategy, sampler, disable_nsfw):
|
||||||
def callback(step: int):
|
def callback(step: int):
|
||||||
print(f"sd_step_{step}")
|
print(f"sd_step_{step}")
|
||||||
|
|
||||||
@ -195,11 +203,11 @@ def test_sd_run_local(strategy, sampler):
|
|||||||
model = ModelManager(
|
model = ModelManager(
|
||||||
name="sd1.4",
|
name="sd1.4",
|
||||||
device=device,
|
device=device,
|
||||||
|
# hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None),
|
||||||
hf_access_token=None,
|
hf_access_token=None,
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=True,
|
sd_disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=True,
|
sd_cpu_textencoder=True,
|
||||||
callbacks=[callback]
|
|
||||||
)
|
)
|
||||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
@ -219,3 +227,4 @@ def test_sd_run_local(strategy, sampler):
|
|||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
torch>=1.8.2
|
torch>=1.9.0
|
||||||
opencv-python
|
opencv-python
|
||||||
flask_cors
|
flask_cors
|
||||||
flask==1.1.4
|
flask==1.1.4
|
||||||
|
Loading…
Reference in New Issue
Block a user