IOPaint/iopaint/model/power_paint/powerpaint_tokenizer.py
2024-04-25 22:12:51 +08:00

255 lines
9.2 KiB
Python

import copy
import random
from typing import Any, List, Union
from transformers import CLIPTokenizer
from iopaint.schema import PowerPaintTask
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
if task == PowerPaintTask.object_remove:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
elif task == PowerPaintTask.context_aware:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.shape_guided:
promptA = prompt + " P_shape"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.outpainting:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
else:
promptA = prompt + " P_obj"
promptB = prompt + " P_obj"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
return promptA, promptB, negative_promptA, negative_promptB
def task_to_prompt(task: PowerPaintTask):
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
"", "", task
)
return (
promptA.strip(),
promptB.strip(),
negative_promptA.strip(),
negative_promptB.strip(),
)
class PowerPaintTokenizer:
def __init__(self, tokenizer: CLIPTokenizer):
self.wrapped = tokenizer
self.token_map = {}
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
num_vec_per_token = 10
for placeholder_token in placeholder_tokens:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
output.append(ith_token)
self.token_map[placeholder_token] = output
def __getattr__(self, name: str) -> Any:
if name == "wrapped":
return super().__getattr__("wrapped")
try:
return getattr(self.wrapped, name)
except AttributeError:
try:
return super().__getattr__(name)
except AttributeError:
raise AttributeError(
"'name' cannot be found in both "
f"'{self.__class__.__name__}' and "
f"'{self.__class__.__name__}.tokenizer'."
)
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
"""Attempt to add tokens to the tokenizer.
Args:
tokens (Union[str, List[str]]): The tokens to be added.
"""
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
assert num_added_tokens != 0, (
f"The tokenizer already contains the token {tokens}. Please pass "
"a different `placeholder_token` that is not already in the "
"tokenizer."
)
def get_token_info(self, token: str) -> dict:
"""Get the information of a token, including its start and end index in
the current tokenizer.
Args:
token (str): The token to be queried.
Returns:
dict: The information of the token, including its start and end
index in current tokenizer.
"""
token_ids = self.__call__(token).input_ids
start, end = token_ids[1], token_ids[-2] + 1
return {"name": token, "start": start, "end": end}
def add_placeholder_token(
self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
):
"""Add placeholder tokens to the tokenizer.
Args:
placeholder_token (str): The placeholder token to be added.
num_vec_per_token (int, optional): The number of vectors of
the added placeholder token.
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
"""
output = []
if num_vec_per_token == 1:
self.try_adding_tokens(placeholder_token, *args, **kwargs)
output.append(placeholder_token)
else:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
self.try_adding_tokens(ith_token, *args, **kwargs)
output.append(ith_token)
for token in self.token_map:
if token in placeholder_token:
raise ValueError(
f"The tokenizer already has placeholder token {token} "
f"that can get confused with {placeholder_token} "
"keep placeholder tokens independent"
)
self.token_map[placeholder_token] = output
def replace_placeholder_tokens_in_text(
self,
text: Union[str, List[str]],
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
) -> Union[str, List[str]]:
"""Replace the keywords in text with placeholder tokens. This function
will be called in `self.__call__` and `self.encode`.
Args:
text (Union[str, List[str]]): The text to be processed.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(
self.replace_placeholder_tokens_in_text(
text[i], vector_shuffle=vector_shuffle
)
)
return output
for placeholder_token in self.token_map:
if placeholder_token in text:
tokens = self.token_map[placeholder_token]
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
if vector_shuffle:
tokens = copy.copy(tokens)
random.shuffle(tokens)
text = text.replace(placeholder_token, " ".join(tokens))
return text
def replace_text_with_placeholder_tokens(
self, text: Union[str, List[str]]
) -> Union[str, List[str]]:
"""Replace the placeholder tokens in text with the original keywords.
This function will be called in `self.decode`.
Args:
text (Union[str, List[str]]): The text to be processed.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(self.replace_text_with_placeholder_tokens(text[i]))
return output
for placeholder_token, tokens in self.token_map.items():
merged_tokens = " ".join(tokens)
if merged_tokens in text:
text = text.replace(merged_tokens, placeholder_token)
return text
def __call__(
self,
text: Union[str, List[str]],
*args,
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
**kwargs,
):
"""The call function of the wrapper.
Args:
text (Union[str, List[str]]): The text to be tokenized.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
)
return self.wrapped.__call__(replaced_text, *args, **kwargs)
def encode(self, text: Union[str, List[str]], *args, **kwargs):
"""Encode the passed text to token index.
Args:
text (Union[str, List[str]]): The text to be encode.
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(text)
return self.wrapped(replaced_text, *args, **kwargs)
def decode(
self, token_ids, return_raw: bool = False, *args, **kwargs
) -> Union[str, List[str]]:
"""Decode the token index to text.
Args:
token_ids: The token index to be decoded.
return_raw: Whether keep the placeholder token in the text.
Defaults to False.
*args, **kwargs: The arguments for `self.wrapped.decode`.
Returns:
Union[str, List[str]]: The decoded text.
"""
text = self.wrapped.decode(token_ids, *args, **kwargs)
if return_raw:
return text
replaced_text = self.replace_text_with_placeholder_tokens(text)
return replaced_text