fix seed generator
This commit is contained in:
parent
3f27712991
commit
8e5e4892af
@ -52,8 +52,6 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
|||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
|
edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
|
||||||
"""
|
"""
|
||||||
set_seed(config.sd_seed)
|
|
||||||
|
|
||||||
output = self.model(
|
output = self.model(
|
||||||
image=PIL.Image.fromarray(image),
|
image=PIL.Image.fromarray(image),
|
||||||
prompt=config.prompt,
|
prompt=config.prompt,
|
||||||
@ -62,6 +60,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
|||||||
image_guidance_scale=config.p2p_image_guidance_scale,
|
image_guidance_scale=config.p2p_image_guidance_scale,
|
||||||
guidance_scale=config.p2p_guidance_scale,
|
guidance_scale=config.p2p_guidance_scale,
|
||||||
output_type="np.array",
|
output_type="np.array",
|
||||||
|
generator=torch.manual_seed(config.sd_seed)
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
output = (output * 255).round().astype("uint8")
|
output = (output * 255).round().astype("uint8")
|
||||||
|
@ -51,14 +51,13 @@ class PaintByExample(DiffusionInpaintModel):
|
|||||||
mask: [H, W, 1] 255 means area to repaint
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
return: BGR IMAGE
|
return: BGR IMAGE
|
||||||
"""
|
"""
|
||||||
set_seed(config.paint_by_example_seed)
|
|
||||||
|
|
||||||
output = self.model(
|
output = self.model(
|
||||||
image=PIL.Image.fromarray(image),
|
image=PIL.Image.fromarray(image),
|
||||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||||
example_image=config.paint_by_example_example_image,
|
example_image=config.paint_by_example_example_image,
|
||||||
num_inference_steps=config.paint_by_example_steps,
|
num_inference_steps=config.paint_by_example_steps,
|
||||||
output_type='np.array',
|
output_type='np.array',
|
||||||
|
generator=torch.manual_seed(config.paint_by_example_seed)
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
output = (output * 255).round().astype("uint8")
|
output = (output * 255).round().astype("uint8")
|
||||||
|
@ -119,8 +119,6 @@ class SD(DiffusionInpaintModel):
|
|||||||
|
|
||||||
self.model.scheduler = scheduler
|
self.model.scheduler = scheduler
|
||||||
|
|
||||||
set_seed(config.sd_seed)
|
|
||||||
|
|
||||||
if config.sd_mask_blur != 0:
|
if config.sd_mask_blur != 0:
|
||||||
k = 2 * config.sd_mask_blur + 1
|
k = 2 * config.sd_mask_blur + 1
|
||||||
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
||||||
@ -138,6 +136,7 @@ class SD(DiffusionInpaintModel):
|
|||||||
callback=self.callback,
|
callback=self.callback,
|
||||||
height=img_h,
|
height=img_h,
|
||||||
width=img_w,
|
width=img_w,
|
||||||
|
generator=torch.manual_seed(config.sd_seed)
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
output = (output * 255).round().astype("uint8")
|
output = (output * 255).round().astype("uint8")
|
||||||
|
Loading…
Reference in New Issue
Block a user