FLUX.1-schnell-fp8-flumina / flux_emphasis.py
aredden's picture
Fix unbound local error when pad_tokens=False
0a91429
from typing import TYPE_CHECKING, Optional
from pydash import flatten
import torch
from transformers.models.clip.tokenization_clip import CLIPTokenizer
from einops import repeat
if TYPE_CHECKING:
from flux_pipeline import FluxPipeline
def parse_prompt_attention(text):
    r"""
    Parse a prompt with attention markup into ``[text, weight]`` pairs.

    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - sets the multiplier for abc to 3.12
      [abc] - decreases attention to abc (multiplier of 1 / 1.1)
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Unclosed brackets are treated as if they were closed at the end of the
    prompt.  The keyword ``BREAK`` (as a whole word) splits the surrounding
    text and is emitted as the marker entry ``["BREAK", -1]``.

    A weight that looks numeric but is not a valid float, e.g. ``(a:1.2.3)``,
    is kept as literal text and the group receives the default ``( )``
    multiplier instead of raising ``ValueError``.

    Args:
        text (str): The prompt to parse.

    Returns:
        list: A list of two-element lists ``[text, weight]``; adjacent
        entries with equal weights are merged.  An input that yields no
        entries produces ``[["", 1.0]]``.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    import re

    # Verbose-mode pattern: whitespace and newlines inside it are ignored.
    re_attention = re.compile(
        r"""
        \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
        \)|]|[^\\()\[\]:]+|:
        """,
        re.X,
    )
    re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)

    res = []
    # Stacks of indices into ``res`` marking where each open bracket started.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every entry emitted since the matching opening bracket.
        for entry in res[start_position:]:
            entry[1] *= multiplier

    for match in re_attention.finditer(text):
        token = match.group(0)
        weight = match.group(1)

        if token.startswith("\\"):
            # Escaped character: drop the backslash, keep the literal.
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            try:
                multiplier = float(weight)
            except ValueError:
                # e.g. "1.2.3" or ".": keep ":<weight>" as literal text and
                # close the group with the default multiplier, which is what
                # a non-numeric "(a:b)" produces as well.
                res.append([token[:-1], 1.0])
                multiplier = round_bracket_multiplier
            multiply_range(round_brackets.pop(), multiplier)
        elif token == ")" and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text (or an unmatched closing token); split on BREAK.
            for index, part in enumerate(re.split(re_break, token)):
                if index > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Brackets that were never closed still apply up to the end of the prompt.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    # Merge runs of identical weights.
    merged = [res[0]]
    for piece in res[1:]:
        if piece[1] == merged[-1][1]:
            merged[-1][0] += piece[0]
        else:
            merged.append(piece)
    return merged
def get_prompts_tokens_with_weights(
    clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False
):
    """
    Tokenize a weighted prompt and return flat, parallel lists of token ids
    and per-token weights.

    The prompt is first split into ``(text, weight)`` segments by
    ``parse_prompt_attention``.  Each segment is tokenized on its own without
    special tokens, truncation or padding, and its weight is repeated once
    for every token it produced.

    Args:
        clip_tokenizer (CLIPTokenizer)
            The tokenizer used to encode each segment.
        prompt (str)
            A prompt string, optionally containing attention markup.
        debug (bool)
            When True, print the token ids, the tokenizer's
            ``model_max_length`` and the decoded text of every segment.

    Returns:
        text_tokens (list)
            Token ids of all segments, concatenated in order.
        text_weight (list)
            The weight of each token id, same length as ``text_tokens``.

    Example:
        token_id_list, token_weight_list = get_prompts_tokens_with_weights(
            clip_tokenizer = clip_tokenizer
            , prompt = "a (red:1.5) cat"*70
        )
    """
    segments = parse_prompt_attention(prompt)
    token_ids = []
    token_weights = []
    model_max_length = clip_tokenizer.model_max_length

    for segment_text, segment_weight in segments:
        # No special tokens and no truncation: prompts of any length are
        # encoded in full, yielding a flat list such as [320, 1125, 539, 320].
        encoded = clip_tokenizer(
            segment_text, truncation=False, padding=False, add_special_tokens=False
        )
        segment_ids = encoded.input_ids

        if debug:
            print(
                segment_ids,
                "|FOR MODEL LEN{}|".format(model_max_length),
                clip_tokenizer.decode(
                    segment_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                ),
            )

        token_ids.extend(segment_ids)
        # One weight per segment is expanded to one weight per token.
        token_weights.extend([segment_weight] * len(segment_ids))

    return token_ids, token_weights
def group_tokens_and_weights(
    token_ids: list,
    weights: list,
    pad_last_block=False,
    bos=49406,
    eos=49407,
    max_length=77,
    pad_tokens=True,
):
    """
    Split token ids and weights into fixed-size groups, optionally wrapping
    each group in BOS/EOS tokens.

    The input lists are not modified.  ``token_ids`` and ``weights`` are
    expected to have the same length.

    Args:
        token_ids (list)
            The token ids from the tokenizer.
        weights (list)
            The weights list from function get_prompts_tokens_with_weights.
        pad_last_block (bool)
            When True (and ``pad_tokens`` is True), fill the last, partial
            group with ``eos`` up to the group size.
        bos (int or None)
            Token id prepended to every group when ``pad_tokens`` is True.
            ``None`` means no BOS token is prepended.
        eos (int)
            Token id appended to every group (and used as filler) when
            ``pad_tokens`` is True.
        max_length (int)
            Model sequence length.  The number of content tokens per group is
            ``max_length - 2`` when ``max_length < 77`` and ``max_length``
            otherwise.
        pad_tokens (bool)
            When False, groups contain only content tokens: no BOS/EOS and no
            padding.

    Returns:
        new_token_ids (2d list)
        new_weights (2d list)
            Added BOS/EOS/padding positions carry a weight of 1.0.

    Raises:
        ValueError: if ``max_length`` leaves no room for content tokens.

    Example:
        token_groups,weight_groups = group_tokens_and_weights(
            token_ids = token_id_list
            , weights = token_weight_list
        )
    """
    # TODO: Possibly need to fix this, since this doesn't seem correct.
    # Ignoring for now since I don't know what the consequences might be
    # if changed to <= instead of <.
    max_len = max_length - 2 if max_length < 77 else max_length
    if max_len <= 0:
        # A non-positive group size would never consume any tokens.
        raise ValueError(
            "max_length={} leaves no room for content tokens".format(max_length)
        )

    # Tokenizers without a BOS token report None; prepend nothing then.
    bos_ids = [] if bos is None else [bos]
    bos_weights = [] if bos is None else [1.0]

    # these will be 2d lists
    new_token_ids = []
    new_weights = []

    full_span = (len(token_ids) // max_len) * max_len

    # Full groups of exactly max_len content tokens.
    for start in range(0, full_span, max_len):
        chunk_ids = list(token_ids[start:start + max_len])
        chunk_weights = list(weights[start:start + max_len])
        if pad_tokens:
            chunk_ids = bos_ids + chunk_ids + [eos]
            chunk_weights = bos_weights + chunk_weights + [1.0]
        new_token_ids.append(chunk_ids)
        new_weights.append(chunk_weights)

    # Trailing partial group, if any tokens are left over.
    rest_ids = list(token_ids[full_span:])
    rest_weights = list(weights[full_span:])
    if rest_ids:
        if pad_tokens:
            padding_len = max_len - len(rest_ids) if pad_last_block else 0
            rest_ids = bos_ids + rest_ids + [eos] * padding_len + [eos]
            rest_weights = bos_weights + rest_weights + [1.0] * padding_len + [1.0]
        new_token_ids.append(rest_ids)
        new_weights.append(rest_weights)

    return new_token_ids, new_weights
def standardize_tensor(
    input_tensor: torch.Tensor, target_mean: float, target_std: float
) -> torch.Tensor:
    """
    Rescale a tensor so that its overall mean and standard deviation match
    the requested targets.

    The statistics are taken over all elements of ``input_tensor``
    (``.mean()`` / ``.std()`` without a dimension argument).

    Parameters:
    input_tensor (torch.Tensor): The tensor to rescale.
    target_mean (float): Mean the result should have.
    target_std (float): Standard deviation the result should have.

    Returns:
    torch.Tensor: A new tensor with the requested mean and std.
    """
    # Normalize to zero mean / unit std, then map onto the target statistics.
    zero_mean_unit_std = (input_tensor - input_tensor.mean()) / input_tensor.std()
    return zero_mean_unit_std * target_std + target_mean
def apply_weights(
    prompt_tokens: torch.Tensor,
    weight_tensor: torch.Tensor,
    token_embedding: torch.Tensor,
    eos_token_id: int,
    pad_last_block: bool = True,
) -> torch.FloatTensor:
    """
    Apply per-token emphasis weights to a batch of token embeddings.

    Every position whose weight differs from 1.0 is moved along the line
    between a per-sample reference embedding and its current value:
    ``reference + (embedding - reference) * weight``.  The result is then
    rescaled so that its overall mean and std equal those of the input.

    Note that ``token_embedding`` is modified in place before the rescaled
    copy is returned.

    Parameters:
    prompt_tokens (torch.Tensor): Token ids, used to locate ``eos_token_id``
        along the last dimension.
    weight_tensor (torch.Tensor): One weight per position (dimension 1 of
        ``token_embedding``).
    token_embedding (torch.Tensor): Embeddings indexed as
        ``[batch, position, ...]``.
    eos_token_id (int): Id of the EOS token.
    pad_last_block (bool): When True the reference is the embedding at the
        position returned by ``argmax`` over the EOS mask (presumably the
        first EOS; position 0 if no EOS is present - verify against the
        torch version in use).  When False it is the last position.

    Returns:
    torch.FloatTensor: The weighted embeddings, restored to the original
        mean and std.
    """
    # Capture the statistics before the embeddings are edited in place.
    original_mean = token_embedding.mean()
    original_std = token_embedding.std()

    if not pad_last_block:
        reference = token_embedding[:, -1]
    else:
        eos_mask = (
            prompt_tokens.to(dtype=torch.int, device=token_embedding.device)
            == eos_token_id
        )
        eos_position = eos_mask.int().argmax(dim=-1)
        batch_index = torch.arange(
            token_embedding.shape[0], device=token_embedding.device
        )
        reference = token_embedding[batch_index, eos_position]

    for position, weight in enumerate(weight_tensor):
        if weight != 1.0:
            offset = token_embedding[:, position] - reference
            token_embedding[:, position] = reference + offset * weight

    return standardize_tensor(token_embedding, original_mean, original_std)
@torch.inference_mode()
def get_weighted_text_embeddings_flux(
    pipe: "FluxPipeline",
    prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    target_device: Optional[torch.device] = torch.device("cuda:0"),
    target_dtype: Optional[torch.dtype] = torch.bfloat16,
    debug: bool = False,
):
    """
    Encode a prompt with attention markup into CLIP and T5 embeddings for a
    Flux pipeline.

    The attention markup is stripped from the prompt for both encoders.  The
    emphasis weights are applied to the T5 token embeddings only; the CLIP
    branch returns the plain pooled output.  Both encoders truncate the
    prompt to their maximum length (77 tokens for CLIP; 512 for T5 when
    ``pipe.name == "flux-dev"``, otherwise 256), and the T5 weights are
    truncated to match.

    Args:
        pipe (FluxPipeline)
            Pipeline providing ``clip`` and ``t5`` (each with ``tokenizer``
            and ``hf_module``), ``name`` and ``_execution_device``.
        prompt (str)
            Prompt string, optionally with attention markup.
        num_images_per_prompt (int)
            Batch size the embeddings are repeated to.
        device (torch.device)
            Device the text encoders are run on; defaults to
            ``pipe._execution_device``.
        target_device (torch.device)
            Device of the returned tensors.
        target_dtype (torch.dtype)
            Dtype of the returned tensors.
        debug (bool)
            Print tokenization details and the T5 embedding shape.

    Returns:
        clip_embeds (torch.Tensor)
            The CLIP ``pooler_output``.
        t5_embeds (torch.Tensor)
            The weighted T5 ``last_hidden_state``.
        txt_ids (torch.Tensor)
            Zeros of shape ``(num_images_per_prompt, t5_embeds.shape[1], 3)``.
    """
    device = device or pipe._execution_device

    eos = pipe.clip.tokenizer.eos_token_id
    eos_2 = pipe.t5.tokenizer.eos_token_id
    bos = pipe.clip.tokenizer.bos_token_id
    bos_2 = pipe.t5.tokenizer.bos_token_id

    clip = pipe.clip.hf_module
    t5 = pipe.t5.hf_module

    tokenizer_clip = pipe.clip.tokenizer
    tokenizer_t5 = pipe.t5.tokenizer

    t5_length = 512 if pipe.name == "flux-dev" else 256
    clip_length = 77

    # tokenizer 1
    prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights(
        tokenizer_clip, prompt, debug=debug
    )

    # tokenizer 2
    prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights(
        tokenizer_t5, prompt, debug=debug
    )

    # The CLIP weights are not used below: only the pooled output is returned.
    prompt_tokens_clip_grouped, _ = group_tokens_and_weights(
        prompt_tokens_clip,
        prompt_weights_clip,
        pad_last_block=True,
        bos=bos,
        eos=eos,
        max_length=clip_length,
    )
    prompt_tokens_t5_grouped, prompt_weights_t5_grouped = group_tokens_and_weights(
        prompt_tokens_t5,
        prompt_weights_t5,
        pad_last_block=True,
        bos=bos_2,
        eos=eos_2,
        max_length=t5_length,
        pad_tokens=False,
    )
    prompt_tokens_t5 = flatten(prompt_tokens_t5_grouped)
    prompt_weights_t5 = flatten(prompt_weights_t5_grouped)
    prompt_tokens_clip = flatten(prompt_tokens_clip_grouped)

    # Round-trip through text so each tokenizer adds its own special tokens,
    # padding and truncation.
    prompt_tokens_clip = tokenizer_clip.decode(
        prompt_tokens_clip, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    prompt_tokens_clip = tokenizer_clip(
        prompt_tokens_clip,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=clip_length,
        return_tensors="pt",
    ).input_ids.to(device)

    prompt_tokens_t5 = tokenizer_t5.decode(
        prompt_tokens_t5, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    prompt_tokens_t5 = tokenizer_t5(
        prompt_tokens_t5,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=t5_length,
        return_tensors="pt",
    ).input_ids.to(device)

    # The token ids above are truncated/padded to t5_length, so the weights
    # must be too: drop weights beyond t5_length (previously a long prompt
    # produced a negative padding size and crashed) and fill the remainder
    # with the neutral weight 1.0.
    kept_weights = list(prompt_weights_t5[:t5_length])
    weight_tensor_t5 = torch.tensor(
        kept_weights + [1.0] * (t5_length - len(kept_weights)),
        dtype=torch.float32,
        device=device,
    )

    clip_embeds = clip(
        prompt_tokens_clip, output_hidden_states=True, attention_mask=None
    )["pooler_output"]
    if clip_embeds.shape[0] == 1 and num_images_per_prompt > 1:
        clip_embeds = repeat(clip_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)

    t5_embeds = t5(prompt_tokens_t5, output_hidden_states=True, attention_mask=None)[
        "last_hidden_state"
    ]
    t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2)
    if debug:
        print(t5_embeds.shape)
    if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1:
        t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt)

    txt_ids = torch.zeros(
        num_images_per_prompt,
        t5_embeds.shape[1],
        3,
        device=target_device,
        dtype=target_dtype,
    )
    t5_embeds = t5_embeds.to(target_device, dtype=target_dtype)
    clip_embeds = clip_embeds.to(target_device, dtype=target_dtype)

    return (
        clip_embeds,
        t5_embeds,
        txt_ids,
    )