70af4845af
new file: inpaint/__main__.py
new file: inpaint/api.py
new file: inpaint/batch_processing.py
new file: inpaint/benchmark.py
new file: inpaint/cli.py
new file: inpaint/const.py
new file: inpaint/download.py
new file: inpaint/file_manager/__init__.py
new file: inpaint/file_manager/file_manager.py
new file: inpaint/file_manager/storage_backends.py
new file: inpaint/file_manager/utils.py
new file: inpaint/helper.py
new file: inpaint/installer.py
new file: inpaint/model/__init__.py
new file: inpaint/model/anytext/__init__.py
new file: inpaint/model/anytext/anytext_model.py
new file: inpaint/model/anytext/anytext_pipeline.py
new file: inpaint/model/anytext/anytext_sd15.yaml
new file: inpaint/model/anytext/cldm/__init__.py
new file: inpaint/model/anytext/cldm/cldm.py
new file: inpaint/model/anytext/cldm/ddim_hacked.py
new file: inpaint/model/anytext/cldm/embedding_manager.py
new file: inpaint/model/anytext/cldm/hack.py
new file: inpaint/model/anytext/cldm/model.py
new file: inpaint/model/anytext/cldm/recognizer.py
new file: inpaint/model/anytext/ldm/__init__.py
new file: inpaint/model/anytext/ldm/models/__init__.py
new file: inpaint/model/anytext/ldm/models/autoencoder.py
new file: inpaint/model/anytext/ldm/models/diffusion/__init__.py
new file: inpaint/model/anytext/ldm/models/diffusion/ddim.py
new file: inpaint/model/anytext/ldm/models/diffusion/ddpm.py
new file: inpaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py
new file: inpaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file: inpaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py
new file: inpaint/model/anytext/ldm/models/diffusion/plms.py
new file: inpaint/model/anytext/ldm/models/diffusion/sampling_util.py
new file: inpaint/model/anytext/ldm/modules/__init__.py
new file: inpaint/model/anytext/ldm/modules/attention.py
new file: inpaint/model/anytext/ldm/modules/diffusionmodules/__init__.py
new file: inpaint/model/anytext/ldm/modules/diffusionmodules/model.py
new file: inpaint/model/anytext/ldm/modules/diffusionmodules/openaimodel.py
new file: inpaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
new file: inpaint/model/anytext/ldm/modules/diffusionmodules/util.py
new file: inpaint/model/anytext/ldm/modules/distributions/__init__.py
new file: inpaint/model/anytext/ldm/modules/distributions/distributions.py
new file: inpaint/model/anytext/ldm/modules/ema.py
new file: inpaint/model/anytext/ldm/modules/encoders/__init__.py
new file: inpaint/model/anytext/ldm/modules/encoders/modules.py
new file: inpaint/model/anytext/ldm/util.py
new file: inpaint/model/anytext/main.py
new file: inpaint/model/anytext/ocr_recog/RNN.py
new file: inpaint/model/anytext/ocr_recog/RecCTCHead.py
new file: inpaint/model/anytext/ocr_recog/RecModel.py
new file: inpaint/model/anytext/ocr_recog/RecMv1_enhance.py
new file: inpaint/model/anytext/ocr_recog/RecSVTR.py
new file: inpaint/model/anytext/ocr_recog/__init__.py
new file: inpaint/model/anytext/ocr_recog/common.py
new file: inpaint/model/anytext/ocr_recog/en_dict.txt
new file: inpaint/model/anytext/ocr_recog/ppocr_keys_v1.txt
new file: inpaint/model/anytext/utils.py
new file: inpaint/model/base.py
new file: inpaint/model/brushnet/__init__.py
new file: inpaint/model/brushnet/brushnet.py
new file: inpaint/model/brushnet/brushnet_unet_forward.py
new file: inpaint/model/brushnet/brushnet_wrapper.py
new file: inpaint/model/brushnet/pipeline_brushnet.py
new file: inpaint/model/brushnet/unet_2d_blocks.py
new file: inpaint/model/controlnet.py
new file: inpaint/model/ddim_sampler.py
new file: inpaint/model/fcf.py
new file: inpaint/model/helper/__init__.py
new file: inpaint/model/helper/controlnet_preprocess.py
new file: inpaint/model/helper/cpu_text_encoder.py
new file: inpaint/model/helper/g_diffuser_bot.py
new file: inpaint/model/instruct_pix2pix.py
new file: inpaint/model/kandinsky.py
new file: inpaint/model/lama.py
new file: inpaint/model/ldm.py
new file: inpaint/model/manga.py
new file: inpaint/model/mat.py
new file: inpaint/model/mi_gan.py
new file: inpaint/model/opencv2.py
new file: inpaint/model/original_sd_configs/__init__.py
new file: inpaint/model/original_sd_configs/sd_xl_base.yaml
new file: inpaint/model/original_sd_configs/sd_xl_refiner.yaml
new file: inpaint/model/original_sd_configs/v1-inference.yaml
new file: inpaint/model/original_sd_configs/v2-inference-v.yaml
new file: inpaint/model/paint_by_example.py
new file: inpaint/model/plms_sampler.py
new file: inpaint/model/power_paint/__init__.py
new file: inpaint/model/power_paint/pipeline_powerpaint.py
new file: inpaint/model/power_paint/power_paint.py
new file: inpaint/model/power_paint/power_paint_v2.py
new file: inpaint/model/power_paint/powerpaint_tokenizer.py
175 lines
7.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # The default mean/std are the standard ImageNet RGB statistics scaled
        # to the 0-255 pixel range. They are registered as non-persistent
        # buffers: they follow the module across devices but are not saved in
        # checkpoints.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device
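
    # A minimal construction sketch (hypothetical: the component
    # hyperparameters are illustrative and elided, not a reference
    # configuration from this repository):
    #
    #     sam = Sam(
    #         image_encoder=ImageEncoderViT(...),  # ViT backbone; img_size is typically 1024
    #         prompt_encoder=PromptEncoder(...),   # embeds point/box/mask prompts
    #         mask_decoder=MaskDecoder(...),       # decodes embeddings into masks
    #     )
    #     sam.eval()  # inference only; forward is wrapped in @torch.no_grad()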

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs
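
    # Example call (hypothetical tensors; `sam` is a constructed Sam module
    # and `image` is a 3xHxW float tensor already transformed to the model's
    # input frame, e.g. long side resized to image_encoder.img_size):
    #
    #     batched_input = [{
    #         "image": image,                     # 3xHxW, model input frame
    #         "original_size": (orig_h, orig_w),  # size before resizing
    #         "point_coords": coords,             # BxNx2 point prompts
    #         "point_labels": labels,             # BxN (1=foreground, 0=background)
    #     }]
    #     outputs = sam(batched_input, multimask_output=True)
    #     masks = outputs[0]["masks"]             # BxCx(orig_h)x(orig_w), boolean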

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks
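
    # Shape walkthrough (illustrative numbers, assuming img_size=1024 and an
    # input resized to 768x1024 from an original 1200x1600 image):
    #   low-res logits BxCx256x256
    #     -> bilinear upscale to BxCx1024x1024  (model input resolution)
    #     -> crop padding to BxCx768x1024       (input_size strips the pad)
    #     -> bilinear resize to BxCx1200x1600   (original_size)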

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
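
    # Padding example (illustrative): with img_size=1024 and a resized input
    # of 768x1024, padh=256 and padw=0, so F.pad appends 256 zero rows at the
    # bottom, producing a square 1024x1024 tensor for the ViT encoder.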