fix seed generator

This commit is contained in:
Qing 2023-03-01 09:13:23 +08:00
parent 3f27712991
commit 8e5e4892af
3 changed files with 3 additions and 6 deletions

View File

@ -52,8 +52,6 @@ class InstructPix2Pix(DiffusionInpaintModel):
return: BGR IMAGE return: BGR IMAGE
edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0] edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
""" """
set_seed(config.sd_seed)
output = self.model( output = self.model(
image=PIL.Image.fromarray(image), image=PIL.Image.fromarray(image),
prompt=config.prompt, prompt=config.prompt,
@ -62,6 +60,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
image_guidance_scale=config.p2p_image_guidance_scale, image_guidance_scale=config.p2p_image_guidance_scale,
guidance_scale=config.p2p_guidance_scale, guidance_scale=config.p2p_guidance_scale,
output_type="np.array", output_type="np.array",
generator=torch.manual_seed(config.sd_seed)
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")

View File

@ -51,14 +51,13 @@ class PaintByExample(DiffusionInpaintModel):
mask: [H, W, 1] 255 means area to repaint mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE return: BGR IMAGE
""" """
set_seed(config.paint_by_example_seed)
output = self.model( output = self.model(
image=PIL.Image.fromarray(image), image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
example_image=config.paint_by_example_example_image, example_image=config.paint_by_example_example_image,
num_inference_steps=config.paint_by_example_steps, num_inference_steps=config.paint_by_example_steps,
output_type='np.array', output_type='np.array',
generator=torch.manual_seed(config.paint_by_example_seed)
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")

View File

@ -119,8 +119,6 @@ class SD(DiffusionInpaintModel):
self.model.scheduler = scheduler self.model.scheduler = scheduler
set_seed(config.sd_seed)
if config.sd_mask_blur != 0: if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1 k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
@ -138,6 +136,7 @@ class SD(DiffusionInpaintModel):
callback=self.callback, callback=self.callback,
height=img_h, height=img_h,
width=img_w, width=img_w,
generator=torch.manual_seed(config.sd_seed)
).images[0] ).images[0]
output = (output * 255).round().astype("uint8") output = (output * 255).round().astype("uint8")