This commit is contained in:
Qing 2023-11-14 14:02:10 +08:00
parent 557e28aff9
commit d061b07029

View File

@ -24,13 +24,6 @@ class Kandinsky(DiffusionInpaintModel):
"torch_dtype": torch_dtype, "torch_dtype": torch_dtype,
} }
# self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
# self.prior_name, **model_kwargs
# ).to("cpu")
#
# self.model = KandinskyInpaintPipeline.from_pretrained(
# self.model_name, **model_kwargs
# ).to(device)
self.model = AutoPipelineForInpainting.from_pretrained( self.model = AutoPipelineForInpainting.from_pretrained(
self.model_name, **model_kwargs self.model_name, **model_kwargs
).to(device) ).to(device)
@ -54,6 +47,7 @@ class Kandinsky(DiffusionInpaintModel):
mask = mask.astype(np.float32) / 255 mask = mask.astype(np.float32) / 255
img_h, img_w = image.shape[:2] img_h, img_w = image.shape[:2]
# kandinsky 没有 strength
output = self.model( output = self.model(
prompt=config.prompt, prompt=config.prompt,
negative_prompt=config.negative_prompt, negative_prompt=config.negative_prompt,
@ -65,6 +59,7 @@ class Kandinsky(DiffusionInpaintModel):
guidance_scale=config.sd_guidance_scale, guidance_scale=config.sd_guidance_scale,
output_type="np", output_type="np",
callback=self.callback, callback=self.callback,
generator=generator,
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")