zits use structure_upsample_model
commit cfcaf82a21
parent b0c5d22a5a
@@ -8,7 +8,7 @@
   background-color: var(--modal-bg);
   color: var(--modal-text-color);
   box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2);
-  width: 700px;
+  width: 600px;
 
   @include mobile {
     display: grid;
@@ -42,17 +42,23 @@ def load_jit_model(url_or_path, device):
     else:
         model_path = download_model(url_or_path)
     logger.info(f"Load model from: {model_path}")
-    model = torch.jit.load(model_path).to(device)
+    try:
+        model = torch.jit.load(model_path).to(device)
+    except:
+        logger.error(
+            f"Failed to load {model_path}, delete model and restart lama-cleaner"
+        )
+        exit(-1)
     model.eval()
     return model
 
 
 def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
-    data = cv2.imencode(f".{ext}", image_numpy,
-                        [
-                            int(cv2.IMWRITE_JPEG_QUALITY), 100,
-                            int(cv2.IMWRITE_PNG_COMPRESSION), 0
-                        ])[1]
+    data = cv2.imencode(
+        f".{ext}",
+        image_numpy,
+        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
+    )[1]
     image_bytes = data.tobytes()
     return image_bytes
 
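The load_jit_model change above wraps the TorchScript load in a try/except so that a corrupted or truncated cached download produces an actionable message instead of an opaque stack trace. Below is a minimal, hedged sketch of the same pattern outside the project; the function name, path argument, and plain stderr logging are placeholders, not lama-cleaner's own helpers.

import sys
import torch

def load_scripted_model(path: str, device: str = "cpu"):
    # Sketch only: a corrupt or partial .pt file makes torch.jit.load raise RuntimeError.
    try:
        model = torch.jit.load(path, map_location=device)
    except Exception as err:
        print(f"Failed to load {path}: {err}. Delete the file and retry.", file=sys.stderr)
        sys.exit(-1)
    model.eval()
    return model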
@@ -95,7 +101,9 @@ def resize_max_size(
     return np_img
 
 
-def pad_img_to_modulo(img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None):
+def pad_img_to_modulo(
+    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
+):
     """
 
     Args:
@@ -41,7 +41,7 @@ def resize(img, height, width, center_crop=False):
         side = np.minimum(imgh, imgw)
         j = (imgh - side) // 2
         i = (imgw - side) // 2
-        img = img[j: j + side, i: i + side, ...]
+        img = img[j : j + side, i : i + side, ...]
 
     if imgh > height and imgw > width:
         inter = cv2.INTER_AREA
@@ -219,7 +219,9 @@ class ZITS(InpaintModel):
     def init_model(self, device):
         self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device)
         self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device)
-        # self.structure_upsample = load_jit_model(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device)
+        self.structure_upsample = load_jit_model(
+            ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device
+        )
         self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device)
 
     @staticmethod
@@ -227,7 +229,7 @@ class ZITS(InpaintModel):
         model_paths = [
             get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
             get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
-            # get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
+            get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
             get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
         ]
         return all([os.path.exists(it) for it in model_paths])
@@ -272,20 +274,27 @@ class ZITS(InpaintModel):
             # cv2.imwrite("line_pred.jpg", np_line_pred)
             # exit()
 
-            # No structure_upsample_model
             input_size = min(items["h"], items["w"])
-            edge_pred = F.interpolate(
-                edge_pred,
-                size=(input_size, input_size),
-                mode="bilinear",
-                align_corners=False,
-            )
-            line_pred = F.interpolate(
-                line_pred,
-                size=(input_size, input_size),
-                mode="bilinear",
-                align_corners=False,
-            )
+            if input_size != 256 and input_size > 256:
+                while edge_pred.shape[2] < input_size:
+                    edge_pred = self.structure_upsample(edge_pred)
+                    edge_pred = torch.sigmoid((edge_pred + 2) * 2)
+
+                    line_pred = self.structure_upsample(line_pred)
+                    line_pred = torch.sigmoid((line_pred + 2) * 2)
+
+                edge_pred = F.interpolate(
+                    edge_pred,
+                    size=(input_size, input_size),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                line_pred = F.interpolate(
+                    line_pred,
+                    size=(input_size, input_size),
+                    mode="bilinear",
+                    align_corners=False,
+                )
 
             # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
             # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
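This is the core of the commit: instead of bilinearly stretching the 256x256 edge and line predictions straight to the working resolution, they are now grown step by step through the structure_upsample model (each pass roughly doubling them) and re-squashed with the sigmoid((x + 2) * 2) used by ZITS, followed by a final bilinear resize to the exact input size. The following is a runnable, hedged sketch of that flow; the function name is made up and a plain bilinear 2x upsample stands in for the jit-loaded structure_upsample checkpoint.

import torch
import torch.nn.functional as F

def grow_to_input_size(pred: torch.Tensor, upsample, input_size: int) -> torch.Tensor:
    # pred: (B, 1, H, W) edge or line map, e.g. 256x256
    if input_size > 256:
        while pred.shape[2] < input_size:
            pred = upsample(pred)                 # roughly doubles H and W per pass
            pred = torch.sigmoid((pred + 2) * 2)  # squash back to (0, 1) as ZITS does
        # snap to the exact requested size
        pred = F.interpolate(pred, size=(input_size, input_size),
                             mode="bilinear", align_corners=False)
    return pred

# Stand-in for the real structure_upsample model so the sketch runs as-is.
dummy_upsample = lambda x: F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
edge = torch.rand(1, 1, 256, 256)
print(grow_to_input_size(edge, dummy_upsample, 700).shape)  # torch.Size([1, 1, 700, 700])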
@@ -308,12 +317,19 @@ class ZITS(InpaintModel):
 
         self.wireframe_edge_and_line(items, config.zits_wireframe)
 
-        inpainted_image = self.inpaint(items["images"], items["masks"],
-                                       items["edge"], items["line"],
-                                       items["rel_pos"], items["direct"])
+        inpainted_image = self.inpaint(
+            items["images"],
+            items["masks"],
+            items["edge"],
+            items["line"],
+            items["rel_pos"],
+            items["direct"],
+        )
 
         inpainted_image = inpainted_image * 255.0
-        inpainted_image = inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
+        inpainted_image = (
+            inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
+        )
         inpainted_image = inpainted_image[:, :, ::-1]
 
         # cv2.imwrite("inpainted.jpg", inpainted_image)
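The reformatted lines above only change layout: the inpainting output is still scaled to [0, 255], moved from NCHW float tensors to an HWC uint8 array, and flipped from RGB to BGR for OpenCV. A small hedged sketch of that conversion, using a random tensor in place of the real inpaint output:

import numpy as np
import torch

output = torch.rand(1, 3, 64, 64)                # stand-in for the model output, RGB in [0, 1]
img = output * 255.0
img = img.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)  # NCHW float -> HWC uint8
img = img[:, :, ::-1]                            # RGB -> BGR so cv2.imwrite saves correct colors
print(img.shape, img.dtype)                      # (64, 64, 3) uint8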
@@ -362,7 +378,9 @@ class ZITS(InpaintModel):
         lines_tensor = torch.cat(lines_tensor, dim=0)
         return lines_tensor.detach().to(self.device)
 
-    def sample_edge_line_logits(self, context, mask=None, iterations=1, add_v=0, mul_v=4):
+    def sample_edge_line_logits(
+        self, context, mask=None, iterations=1, add_v=0, mul_v=4
+    ):
         [img, edge, line] = context
 
         img = img * (1 - mask)
@@ -391,7 +409,9 @@ class ZITS(InpaintModel):
             edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
             line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)
 
-            indices = torch.sort(edge_max_probs + line_max_probs, dim=-1, descending=True)[1]
+            indices = torch.sort(
+                edge_max_probs + line_max_probs, dim=-1, descending=True
+            )[1]
 
             for ii in range(b):
                 keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))
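For context on the last hunk: positions outside the inpainting mask receive a -100 penalty so only masked pixels can rank highly, positions are sorted by combined edge + line confidence, and the number kept grows linearly with the sampling iteration. A toy, hedged reproduction of that selection with made-up shapes and values:

import torch

b, n, iterations = 2, 16, 4                      # toy batch size, flattened positions, iterations
mask = (torch.rand(b, n) > 0.5).float()          # 1 = pixel to inpaint
edge_max_probs = torch.rand(b, n) + (1 - mask) * (-100)
line_max_probs = torch.rand(b, n) + (1 - mask) * (-100)

i = 0                                            # first sampling iteration
indices = torch.sort(edge_max_probs + line_max_probs, dim=-1, descending=True)[1]
for ii in range(b):
    keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))
    top = indices[ii, :keep]                     # most confident masked positions
    print(f"batch {ii}: keeping {keep} of {int(mask[ii].sum())} masked positions: {top.tolist()}")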