255 lines
9.2 KiB
Python
255 lines
9.2 KiB
Python
import copy
|
|
import random
|
|
from typing import Any, List, Union
|
|
from transformers import CLIPTokenizer
|
|
|
|
from iopaint.schema import PowerPaintTask
|
|
|
|
|
|
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
    """Append PowerPaint task keywords to a positive/negative prompt pair.

    Two positive prompts (A and B) and two negative prompts are produced.
    Depending on ``task``, a keyword (``P_ctxt``, ``P_shape`` or ``P_obj``)
    is appended, with a leading space, to each positive prompt. For some
    tasks ``P_obj`` is also appended to both negative prompts.

    Args:
        prompt: The user's positive prompt.
        negative_prompt: The user's negative prompt.
        task: The PowerPaint task. A value that matches none of the four
            explicitly handled tasks falls back to ``P_obj`` on both
            positive prompts and leaves the negative prompt unchanged.

    Returns:
        tuple: ``(promptA, promptB, negative_promptA, negative_promptB)``.
    """
    # Rows: (task, promptA suffix, promptB suffix, negative suffix).
    # A negative suffix of None means "pass negative_prompt through as-is".
    # Rows are tested in order with ``==``, exactly like an if/elif chain.
    rules = (
        (PowerPaintTask.object_remove, " P_ctxt", " P_ctxt", " P_obj"),
        (PowerPaintTask.context_aware, " P_ctxt", " P_ctxt", None),
        (PowerPaintTask.shape_guided, " P_shape", " P_ctxt", None),
        (PowerPaintTask.outpainting, " P_ctxt", " P_ctxt", " P_obj"),
    )

    # Fallback for any task not listed above.
    suffix_a, suffix_b, negative_suffix = " P_obj", " P_obj", None
    for candidate, row_a, row_b, row_negative in rules:
        if task == candidate:
            suffix_a, suffix_b, negative_suffix = row_a, row_b, row_negative
            break

    prompt_a = prompt + suffix_a
    prompt_b = prompt + suffix_b
    if negative_suffix is None:
        negative_a = negative_prompt
        negative_b = negative_prompt
    else:
        negative_a = negative_prompt + negative_suffix
        negative_b = negative_prompt + negative_suffix
    return prompt_a, prompt_b, negative_a, negative_b
|
|
|
|
|
|
def task_to_prompt(task: PowerPaintTask):
    """Return only the task keywords that ``add_task_to_prompt`` would append.

    The keywords are obtained by running ``add_task_to_prompt`` on empty
    prompts and stripping the surrounding whitespace from each result.

    Args:
        task: The PowerPaint task.

    Returns:
        tuple: ``(promptA, promptB, negative_promptA, negative_promptB)``,
        each either a bare keyword (for example ``"P_ctxt"``) or ``""``.
    """
    prompt_a, prompt_b, neg_a, neg_b = add_task_to_prompt("", "", task)
    return tuple(piece.strip() for piece in (prompt_a, prompt_b, neg_a, neg_b))
|
|
|
|
|
|
class PowerPaintTokenizer:
    """Tokenizer wrapper that expands PowerPaint keywords into multi-vector tokens.

    Each keyword in ``token_map`` (initially ``P_ctxt``, ``P_shape`` and
    ``P_obj``) stands for a sequence of sub-tokens (``P_ctxt_0`` ...
    ``P_ctxt_9`` and so on). Text passed to ``__call__`` / ``encode`` has
    its keywords expanded before reaching the wrapped tokenizer, and
    ``decode`` collapses expanded sequences back into keywords. Any
    attribute not defined on this class is looked up on the wrapped
    tokenizer.
    """

    def __init__(self, tokenizer: "CLIPTokenizer"):
        """Wrap ``tokenizer`` and register the built-in PowerPaint keywords.

        Args:
            tokenizer: The tokenizer to wrap, stored as ``self.wrapped``.
        """
        self.wrapped = tokenizer
        # Keyword -> list of sub-token strings it expands to. The sub-tokens
        # are NOT added to ``tokenizer`` here; presumably its vocabulary
        # already contains them (TODO confirm against the model checkpoint).
        self.token_map = {}
        placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
        num_vec_per_token = 10
        for placeholder_token in placeholder_tokens:
            self.token_map[placeholder_token] = [
                f"{placeholder_token}_{i}" for i in range(num_vec_per_token)
            ]

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped tokenizer.

        Python calls this only after normal lookup on this instance has
        failed.

        Raises:
            AttributeError: If neither this wrapper nor the wrapped
                tokenizer has ``name``.
        """
        # ``wrapped`` missing means the instance is half-initialised (e.g.
        # created via ``__new__`` by copy/pickle). Fail immediately so that
        # ``self.wrapped`` below cannot recurse back into __getattr__.
        if name == "wrapped":
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'wrapped'"
            )
        try:
            return getattr(self.wrapped, name)
        except AttributeError as err:
            raise AttributeError(
                f"'{name}' cannot be found in either "
                f"'{self.__class__.__name__}' or "
                f"'{self.__class__.__name__}.wrapped'."
            ) from err

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Add tokens to the wrapped tokenizer, failing if none were new.

        Args:
            tokens (Union[str, List[str]]): The token(s) to be added.
            *args, **kwargs: Forwarded to ``self.wrapped.add_tokens``.

        Raises:
            AssertionError: If ``add_tokens`` reports that zero tokens were
                added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type callers see is unchanged.
        if num_added_tokens == 0:
            raise AssertionError(
                f"The tokenizer already contains the token {tokens}. Please pass "
                "a different `placeholder_token` that is not already in the "
                "tokenizer."
            )

    def get_token_info(self, token: str) -> dict:
        """Get the start and end token index of ``token`` in the tokenizer.

        ``token`` is passed through ``self.__call__``, so registered keywords
        are expanded first. ``start`` is the id at position 1 and ``end`` is
        the id at position -2 plus one.

        NOTE(review): this assumes the wrapped tokenizer puts exactly one
        special token at each end of ``input_ids`` and that the sub-token
        ids are consecutive. Confirm for non-CLIP tokenizers.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: ``{"name": token, "start": start, "end": end}``.
        """
        token_ids = self.__call__(token).input_ids
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(
        self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
    ):
        """Register a new placeholder keyword and add its tokens.

        With ``num_vec_per_token == 1`` the keyword itself is added as a
        token. Otherwise ``placeholder_token_0`` ...
        ``placeholder_token_{n-1}`` are added, and the keyword expands to
        all of them.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token. Must be at least 1.
            *args, **kwargs: The arguments for ``self.wrapped.add_tokens``.

        Raises:
            ValueError: If ``num_vec_per_token`` < 1, or if the new keyword
                overlaps (one is a substring of the other) an already
                registered keyword.
            AssertionError: If a token is already in the wrapped tokenizer.
        """
        if num_vec_per_token < 1:
            raise ValueError(
                f"num_vec_per_token must be >= 1, got {num_vec_per_token}"
            )

        # Validate BEFORE modifying the wrapped tokenizer, so a rejected
        # keyword does not leave orphan tokens in its vocabulary. Overlap in
        # either direction is rejected because keyword expansion uses plain
        # ``str.replace``, and one keyword would match inside the other.
        for token in self.token_map:
            if token in placeholder_token or placeholder_token in token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token} "
                    "keep placeholder tokens independent"
                )

        if num_vec_per_token == 1:
            output = [placeholder_token]
        else:
            output = [f"{placeholder_token}_{i}" for i in range(num_vec_per_token)]

        for token in output:
            self.try_adding_tokens(token, *args, **kwargs)

        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self,
        text: Union[str, List[str]],
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
    ) -> Union[str, List[str]]:
        """Expand every registered keyword in ``text`` into its sub-tokens.

        Called by ``self.__call__`` and ``self.encode``.

        Args:
            text (Union[str, List[str]]): The text (or list of texts) to be
                processed.
            vector_shuffle (bool, optional): Whether to shuffle the order of
                the sub-tokens. Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of
                sub-tokens to keep. The first
                ``1 + int(len * prop_tokens_to_load)`` are used, so at least
                one is always kept and 1.0 keeps all. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text, with the same
            str/list shape as ``text``.
        """
        if isinstance(text, list):
            # Forward every option, including prop_tokens_to_load, so list
            # input is treated exactly like individual strings.
            return [
                self.replace_placeholder_tokens_in_text(
                    item,
                    vector_shuffle=vector_shuffle,
                    prop_tokens_to_load=prop_tokens_to_load,
                )
                for item in text
            ]

        for placeholder_token, tokens in self.token_map.items():
            if placeholder_token not in text:
                continue
            # Slicing returns a fresh list, so shuffling below never mutates
            # the shared lists stored in ``self.token_map``.
            tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
            if vector_shuffle:
                random.shuffle(tokens)
            text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(
        self, text: Union[str, List[str]]
    ) -> Union[str, List[str]]:
        """Collapse expanded sub-token sequences back into their keywords.

        Only exact, in-order, space-joined full expansions are collapsed.
        Shuffled or truncated expansions are left as they are. Called by
        ``self.decode``.

        Args:
            text (Union[str, List[str]]): The text (or list of texts) to be
                processed.

        Returns:
            Union[str, List[str]]: The processed text, with the same
            str/list shape as ``text``.
        """
        if isinstance(text, list):
            return [self.replace_text_with_placeholder_tokens(item) for item in text]

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """Expand keywords in ``text``, then call the wrapped tokenizer.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens
                to be loaded. If 1.0, all tokens will be loaded.
                Defaults to 1.0.
            *args, **kwargs: The arguments for ``self.wrapped.__call__``.

        Returns:
            Whatever ``self.wrapped.__call__`` returns.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )
        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Expand keywords in ``text`` (all vectors, unshuffled), then tokenize.

        NOTE: despite the name, this calls the wrapped tokenizer's
        ``__call__``, not its ``encode``, so it returns whatever
        ``self.wrapped(...)`` returns.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for ``self.wrapped.__call__``.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(
        self, token_ids, return_raw: bool = False, *args, **kwargs
    ) -> Union[str, List[str]]:
        """Decode token ids to text, optionally collapsing keyword expansions.

        Args:
            token_ids: The token index to be decoded.
            return_raw: If True, return the wrapped tokenizer's output
                unchanged, keeping the expanded sub-tokens. Defaults to
                False.
            *args, **kwargs: The arguments for ``self.wrapped.decode``.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        return self.replace_text_with_placeholder_tokens(text)
|