wip fp16 mat
This commit is contained in:
parent 094b3c4f69
commit 7e028c3908
@@ -21,6 +21,7 @@ from lama_cleaner.model.utils import (
     MinibatchStdLayer,
     to_2tuple,
     normalize_2nd_moment,
+    set_seed,
 )
 
 from lama_cleaner.schema import Config
@@ -361,6 +362,7 @@ class MappingNet(torch.nn.Module):
         activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
         lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
         w_avg_beta=0.995,  # Decay for tracking the moving average of W during training, None = do not track.
+        torch_dtype=torch.float32,
     ):
         super().__init__()
         self.z_dim = z_dim
@@ -369,6 +371,7 @@ class MappingNet(torch.nn.Module):
         self.num_ws = num_ws
         self.num_layers = num_layers
         self.w_avg_beta = w_avg_beta
+        self.torch_dtype = torch_dtype
 
         if embed_features is None:
             embed_features = w_dim
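Note: in the two hunks above, torch_dtype is accepted and stored but not yet consumed; later hunks rely on the surrounding tensors already being in that dtype. A minimal sketch of the general pattern this is working toward, with an illustrative module name and buffer (not lama_cleaner's API):

    import torch

    class DtypeAwareMapping(torch.nn.Module):
        # Hypothetical example: thread a dtype through a module so its
        # buffers and activations match the requested precision.
        def __init__(self, w_dim=512, torch_dtype=torch.float32):
            super().__init__()
            self.torch_dtype = torch_dtype
            # Buffers are created directly in the target dtype.
            self.register_buffer("w_avg", torch.zeros([w_dim], dtype=torch_dtype))

        def forward(self, z):
            # Cast inputs to the module dtype instead of hard-coding float32.
            return z.to(self.torch_dtype) + self.w_avg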
@@ -399,14 +402,16 @@ class MappingNet(torch.nn.Module):
     def forward(
         self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
     ):
+        import ipdb
+
+        ipdb.set_trace()
         # Embed, normalize, and concat inputs.
         x = None
-        with torch.autograd.profiler.record_function("input"):
-            if self.z_dim > 0:
-                x = normalize_2nd_moment(z.to(torch.float32))
-            if self.c_dim > 0:
-                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
-                x = torch.cat([x, y], dim=1) if x is not None else y
+        if self.z_dim > 0:
+            x = normalize_2nd_moment(z)
+        if self.c_dim > 0:
+            y = normalize_2nd_moment(self.embed(c))
+            x = torch.cat([x, y], dim=1) if x is not None else y
 
         # Main layers.
         for idx in range(self.num_layers):
@@ -415,26 +420,21 @@ class MappingNet(torch.nn.Module):
 
         # Update moving average of W.
         if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
-            with torch.autograd.profiler.record_function("update_w_avg"):
-                self.w_avg.copy_(
-                    x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
-                )
+            self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
 
         # Broadcast.
         if self.num_ws is not None:
-            with torch.autograd.profiler.record_function("broadcast"):
-                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+            x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
 
         # Apply truncation.
         if truncation_psi != 1:
-            with torch.autograd.profiler.record_function("truncate"):
-                assert self.w_avg_beta is not None
-                if self.num_ws is None or truncation_cutoff is None:
-                    x = self.w_avg.lerp(x, truncation_psi)
-                else:
-                    x[:, :truncation_cutoff] = self.w_avg.lerp(
-                        x[:, :truncation_cutoff], truncation_psi
-                    )
+            assert self.w_avg_beta is not None
+            if self.num_ws is None or truncation_cutoff is None:
+                x = self.w_avg.lerp(x, truncation_psi)
+            else:
+                x[:, :truncation_cutoff] = self.w_avg.lerp(
+                    x[:, :truncation_cutoff], truncation_psi
+                )
 
         return x
 
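The rewritten forward above drops both the .to(torch.float32) casts (activations now keep the caller's dtype, e.g. float16) and the torch.autograd.profiler.record_function scopes. For context, the truncation step it preserves is the standard StyleGAN truncation trick; a self-contained sketch, assuming ws has shape [N, num_ws, w_dim]:

    import torch

    def truncate_ws(ws, w_avg, psi=0.7, cutoff=None):
        # StyleGAN-style truncation: pull w codes toward the mean w,
        # trading diversity for fidelity. ws: [N, num_ws, w_dim], w_avg: [w_dim].
        if cutoff is None:
            return w_avg.lerp(ws, psi)  # blend every layer's w (broadcasts over N, num_ws)
        ws = ws.clone()
        ws[:, :cutoff] = w_avg.lerp(ws[:, :cutoff], psi)  # only the early layers
        return ws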
@@ -713,7 +713,6 @@ class WindowAttention(nn.Module):
         attn_drop=0.0,
         proj_drop=0.0,
     ):
-
         super().__init__()
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
@@ -1058,7 +1057,6 @@ class BasicLayer(nn.Module):
         downsample=None,
         use_checkpoint=False,
     ):
-
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -1882,14 +1880,22 @@ class MAT(InpaintModel):
 
     def init_model(self, device, **kwargs):
         seed = 240  # pick up a random number
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
+        set_seed(seed)
 
-        G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3)
-        self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5)
-        self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device)  # [1., 512]
-        self.label = torch.zeros([1, self.model.c_dim], device=device)
+        self.torch_dtype = torch.float16
+        G = Generator(
+            z_dim=512,
+            c_dim=0,
+            w_dim=512,
+            img_resolution=512,
+            img_channels=3,
+            mapping_kwargs={"torch_dtype": self.torch_dtype},
+        )
+        # fmt: off
+        self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5).to(self.torch_dtype)
+        self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device)
+        self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
+        # fmt: on
 
     @staticmethod
     def is_downloaded() -> bool:
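init_model now seeds through the shared set_seed helper and builds everything in half precision: the checkpoint weights, the latent z, and the class label are all cast to torch.float16. The same load-then-cast pattern in isolation, with a hypothetical loader (not lama_cleaner's load_model):

    import torch

    def load_half(model: torch.nn.Module, state_dict: dict, device: str) -> torch.nn.Module:
        # Load the fp32 weights first, then cast the whole module to fp16.
        model.load_state_dict(state_dict)
        model = model.to(torch.float16).to(device).eval()
        # Every input tensor must match the weight dtype, otherwise matmuls
        # fail with "expected scalar type Half but found Float".
        return model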
@@ -1909,8 +1915,10 @@ class MAT(InpaintModel):
         mask = 255 - mask
         mask = norm_img(mask)
 
-        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+        image = (
+            torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device)
+        )
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device)
 
         output = self.model(
             image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"
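Casting the weights themselves, as this commit does, is one of two common fp16 routes; the other is autocast, which keeps fp32 weights and downcasts eligible ops per region. A sketch of that alternative for comparison (not part of this commit; model and the inputs stand in for the MAT call above):

    import torch

    def infer_autocast(model, image, mask, z, label):
        # Mixed precision: weights stay fp32, matmul/conv run in fp16
        # only inside the autocast region.
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
            return model(image, mask, z, label, truncation_psi=1, noise_mode="none")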