wip fp16 mat

This commit is contained in:
Qing 2023-03-25 21:29:13 +08:00
parent 094b3c4f69
commit 7e028c3908

View File

@ -21,6 +21,7 @@ from lama_cleaner.model.utils import (
MinibatchStdLayer,
to_2tuple,
normalize_2nd_moment,
set_seed,
)
from lama_cleaner.schema import Config
@ -361,6 +362,7 @@ class MappingNet(torch.nn.Module):
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
torch_dtype=torch.float32,
):
super().__init__()
self.z_dim = z_dim
@ -369,6 +371,7 @@ class MappingNet(torch.nn.Module):
self.num_ws = num_ws
self.num_layers = num_layers
self.w_avg_beta = w_avg_beta
self.torch_dtype = torch_dtype
if embed_features is None:
embed_features = w_dim
@ -399,14 +402,16 @@ class MappingNet(torch.nn.Module):
# Map a latent `z` and a conditioning input `c` to intermediate latents.
# Steps visible in this block: normalize/concatenate the inputs, run the main
# layers, optionally refresh the running average `self.w_avg`, broadcast the
# result to `self.num_ws` copies, and apply truncation towards `self.w_avg`.
#
# Parameters:
#   z                  -- latent input; only used when self.z_dim > 0.
#   c                  -- conditioning input; only used when self.c_dim > 0,
#                         where it is first passed through self.embed.
#   truncation_psi     -- interpolation weight between self.w_avg and the
#                         output; the value 1 skips truncation entirely.
#   truncation_cutoff  -- if set (and self.num_ws is set), only the first
#                         `truncation_cutoff` entries along dim 1 are truncated.
#   skip_w_avg_update  -- when true, the running average is never updated.
# Returns the resulting tensor `x`.
#
# NOTE(review): this text is a rendered diff with the +/- markers stripped, so
# several statements below appear twice (pre-change and post-change form).
# Presumably only the later copy of each pair is live code -- confirm against
# the actual file before relying on this listing.
def forward(
self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
):
# NOTE(review): leftover debugger breakpoint. ipdb.set_trace() drops into an
# interactive debugger on every call to forward(); remove before merging.
import ipdb
ipdb.set_trace()
# Embed, normalize, and concat inputs.
# x stays None only if both self.z_dim and self.c_dim are 0.
x = None
# First copy: inputs are explicitly cast to float32 before normalization and
# the work is wrapped in a profiler range named "input".
with torch.autograd.profiler.record_function("input"):
if self.z_dim > 0:
x = normalize_2nd_moment(z.to(torch.float32))
if self.c_dim > 0:
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
x = torch.cat([x, y], dim=1) if x is not None else y
# Second copy: same logic without the float32 casts, so z and c keep whatever
# dtype the caller passed in.
if self.z_dim > 0:
x = normalize_2nd_moment(z)
if self.c_dim > 0:
y = normalize_2nd_moment(self.embed(c))
x = torch.cat([x, y], dim=1) if x is not None else y
# Main layers.
# The loop body is not visible here: it falls between the two diff hunks.
for idx in range(self.num_layers):
@ -415,26 +420,21 @@ class MappingNet(torch.nn.Module):
# Update moving average of W.
# Runs only when tracking is enabled (w_avg_beta is not None), the module is in
# training mode, and the caller did not pass skip_w_avg_update.
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
with torch.autograd.profiler.record_function("update_w_avg"):
# a.lerp(b, w) computes a + w * (b - a), so the new average is
# batch_mean + w_avg_beta * (w_avg - batch_mean). detach() keeps the update
# out of the autograd graph; copy_ writes into self.w_avg in place.
self.w_avg.copy_(
x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
)
# Single-line form of the same update (second copy).
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
# Broadcast.
# Insert a new dim 1 and repeat the vector self.num_ws times along it.
if self.num_ws is not None:
with torch.autograd.profiler.record_function("broadcast"):
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
# Second copy of the broadcast, without the profiler range.
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
# Apply truncation.
# Skipped for truncation_psi == 1, where the interpolation would return x as is.
if truncation_psi != 1:
with torch.autograd.profiler.record_function("truncate"):
# Truncation reads self.w_avg, so tracking must have been enabled.
# NOTE(review): presumably w_avg only exists when w_avg_beta is set -- confirm
# in __init__, which is not fully visible here.
assert self.w_avg_beta is not None
if self.num_ws is None or truncation_cutoff is None:
# No cutoff (or no broadcast): pull every entry towards w_avg.
x = self.w_avg.lerp(x, truncation_psi)
else:
# Only the first `truncation_cutoff` entries along dim 1 are truncated;
# the slice assignment modifies x in place.
x[:, :truncation_cutoff] = self.w_avg.lerp(
x[:, :truncation_cutoff], truncation_psi
)
# Second copy of the truncation logic, without the profiler range.
assert self.w_avg_beta is not None
if self.num_ws is None or truncation_cutoff is None:
x = self.w_avg.lerp(x, truncation_psi)
else:
x[:, :truncation_cutoff] = self.w_avg.lerp(
x[:, :truncation_cutoff], truncation_psi
)
return x
@ -713,7 +713,6 @@ class WindowAttention(nn.Module):
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
@ -1058,7 +1057,6 @@ class BasicLayer(nn.Module):
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@ -1882,14 +1880,22 @@ class MAT(InpaintModel):
def init_model(self, device, **kwargs):
    """Build the MAT generator, load its weights and cache the fixed inputs.

    Sets ``self.torch_dtype``, ``self.model``, ``self.z`` (a latent of shape
    ``[1, G.z_dim]`` drawn once from NumPy's global RNG) and ``self.label``
    (zeros of shape ``[1, self.model.c_dim]``).

    Args:
        device: Target device for the model and the cached tensors; anything
            accepted by ``torch.device`` (a ``torch.device`` or a string).
        **kwargs: Accepted for interface compatibility; unused here.
    """
    # Fixed, arbitrary seed so the latent drawn below is the same on every run.
    # NOTE(review): set_seed is imported from lama_cleaner.model.utils and is
    # expected to seed random, numpy and torch in one call -- confirm, since
    # np.random.randn below depends on NumPy's global seed.
    seed = 240
    set_seed(seed)

    # Use half precision only on CUDA and fall back to float32 elsewhere.
    # NOTE(review): CPU coverage of float16 operators in PyTorch is
    # incomplete, so hard-coding float16 can break CPU inference -- confirm
    # for the PyTorch versions this project targets.
    on_cuda = torch.device(device).type == "cuda"
    self.torch_dtype = torch.float16 if on_cuda else torch.float32

    # Build and load the generator exactly once, then cast its parameters to
    # the working dtype. The mapping network is told the dtype explicitly.
    G = Generator(
        z_dim=512,
        c_dim=0,
        w_dim=512,
        img_resolution=512,
        img_channels=3,
        mapping_kwargs={"torch_dtype": self.torch_dtype},
    )
    model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5)
    self.model = model.to(self.torch_dtype)

    # Cached inputs are created directly in the model's dtype and on its
    # device so forward passes never mix precisions.
    latent = np.random.randn(1, G.z_dim)  # [1, 512]
    self.z = torch.from_numpy(latent).to(device=device, dtype=self.torch_dtype)
    self.label = torch.zeros(
        [1, self.model.c_dim], device=device, dtype=self.torch_dtype
    )
@staticmethod
def is_downloaded() -> bool:
@ -1909,8 +1915,10 @@ class MAT(InpaintModel):
mask = 255 - mask
mask = norm_img(mask)
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
image = (
torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device)
)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device)
output = self.model(
image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"