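"""SwissArmyTransformer (sat) mixins and model definition for ProteinGLM
generation: 2D rotary position embeddings, DeepNorm residual scaling with a
GEGLU MLP, an FP32-softmax attention path (with a PyTorch >= 2.0
scaled_dot_product_attention fast path), and tied or untied output heads."""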
import math
import contextlib

import torch
import torch.nn as nn
from torch.nn import functional as F

from sat import mpu
from sat.transformer_defaults import attention_fn_default
from sat.mpu.utils import split_tensor_along_last_dim, divide
from sat.mpu.layers import ColumnParallelLinear
from sat.model.base_model import BaseModel, BaseMixin
from sat.model.position_embedding import RotaryEmbedding, apply_rotary_pos_emb_index
from sat.ops import LayerNorm


class RotaryEmbeddingMixin(BaseMixin):
    def __init__(
        self,
        fp16,
        hidden_size,
        num_attention_heads,
        model_parallel_size,
        rotary_embedding_2d=True,
    ):
        super().__init__()
        hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.rotary_embedding_2d = rotary_embedding_2d
        self.num_attention_heads_per_partition = divide(num_attention_heads, model_parallel_size)
        # In 2D mode each of the two position channels rotates half of the
        # head dimension, so the rotary table is built over head_dim // 2.
        self.rotary_emb = RotaryEmbedding(
            hidden_size_per_attention_head // 2
            if rotary_embedding_2d
            else hidden_size_per_attention_head,
            base=10000,
            precision=torch.half if fp16 else torch.bfloat16,
            learnable=False,
            device=torch.cuda.current_device(),
        )
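
    # kw_args["position_ids"] has shape [batch, 2, seq_len] in 2D mode:
    # channel 0 carries the global positions and channel 1 the block
    # positions (GLM-style), matching the two rotary channels applied below.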
    def attention_forward(self, hidden_states, mask, **kw_args):
        attn = self.transformer.layers[kw_args["layer_id"]].attention
        attention_fn = attention_fn_default
        if "attention_fn" in attn.hooks:
            attention_fn = attn.hooks["attention_fn"]

        mixed_raw_layer = attn.query_key_value(hidden_states)

        # [seq_len, batch, 3 * hidden] -> [seq_len, batch, heads, 3 * head_dim]
        new_tensor_shape = mixed_raw_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)

        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)

        dropout_fn = attn.attention_dropout if attn.training else None
        if self.rotary_embedding_2d:
            q1, q2 = query_layer.chunk(2, dim=-1)
            k1, k2 = key_layer.chunk(2, dim=-1)
            cos, sin = self.rotary_emb(q1, seq_len=kw_args["position_ids"].max() + 1)
            position_ids, block_position_ids = \
                kw_args["position_ids"][:, 0, :].transpose(0, 1).contiguous(), \
                kw_args["position_ids"][:, 1, :].transpose(0, 1).contiguous()
            q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
            q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
            query_layer = torch.cat([q1, q2], dim=-1)
            key_layer = torch.cat([k1, k2], dim=-1)
        else:
            kw_args["position_ids"] = kw_args["position_ids"].transpose(0, 1)
            cos, sin = self.rotary_emb(value_layer, seq_len=kw_args["position_ids"].max() + 1)
            query_layer, key_layer = apply_rotary_pos_emb_index(
                query_layer, key_layer, cos, sin, kw_args["position_ids"]
            )

        context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
        output = attn.dense(context_layer)

        if attn.training:
            output = attn.output_dropout(output)

        return output


class GEGLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = F.gelu

    def forward(self, x):
        # Split the doubled hidden dimension into a value half and a gate half.
        x1, x2 = x.chunk(2, dim=-1)
        return x1 * self.activation_fn(x2)
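
# Quick shape check for GEGLU: the gate halves the last dimension, e.g.
#   GEGLU()(torch.randn(4, 8)).shape == torch.Size([4, 4])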

class DeepNormWithGLUMixin(BaseMixin):
    def __init__(self, num_layers, hidden_size, inner_hidden_size=None):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        if inner_hidden_size is None:
            # GEGLU halves the MLP width, so scale the usual 4h by 2/3 to keep
            # the parameter count close to a vanilla 4h feed-forward layer.
            inner_hidden_size = 4 * hidden_size * 2 // 3
        self.inner_hidden_size = inner_hidden_size
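
    # Example: hidden_size = 4096 gives inner_hidden_size = 4 * 4096 * 2 // 3
    # = 10922, so dense_h_to_4h (rebuilt in reinit below) projects
    # 4096 -> 21844 and GEGLU halves it back to 10922.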
    def reinit(self):
        # Swap each MLP's first projection for a doubled-width one (value and
        # gate halves) and replace its activation with GEGLU.
        for layer in self.transformer.layers:
            del layer.mlp.dense_h_to_4h
            layer.mlp.dense_h_to_4h = ColumnParallelLinear(
                self.hidden_size,
                2 * self.inner_hidden_size,
                gather_output=False,
                bias=True,
                params_dtype=torch.half,
                module=self,
                name="dense_h_to_4h",
                skip_init=True,
            )
            del layer.mlp.activation_func
            layer.mlp.activation_func = GEGLU()
    def layer_forward(self, hidden_states, mask, *args, **kw_args):
        """
        hidden_states: [seq_len, batch, hidden_size]
        mask: [(1, 1), seq_len, seq_len]
        """
        layer = self.transformer.layers[kw_args["layer_id"]]

        attention_input = layer.input_layernorm(hidden_states)
        attention_output = layer.attention(attention_input, mask, **kw_args)

        # DeepNorm residual: weight the (post-LN) residual branch by
        # alpha = sqrt(2 * num_layers) for training stability.
        alpha = (2 * self.num_layers) ** 0.5
        hidden_states = attention_input * alpha + attention_output

        mlp_input = layer.post_attention_layernorm(hidden_states)
        mlp_output = layer.mlp(mlp_input, **kw_args)
        output = mlp_input * alpha + mlp_output

        return output
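
    # Example: with num_layers = 36, alpha = (2 * 36) ** 0.5 ~ 8.49, so the
    # residual stream dominates each sublayer's contribution.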


class SelfAttentionWithFP32SoftmaxMixin(BaseMixin):
    def __init__(self, fp16, hidden_size, num_attention_heads, model_parallel_size):
        super().__init__()
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        self.hidden_size_per_partition = divide(hidden_size, model_parallel_size)
        self.scale_mask_softmax = None
        self.fp16 = fp16

    @staticmethod
    def attention_mask_func(attention_scores, attention_mask):
        # Nonzero mask entries mark positions that must not be attended to.
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores
    def attention_fn(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        attention_dropout=None,
        log_attention_weights=None,
        scaling_attention_score=True,
        mems=None,
        **kwargs
    ):
        mem = mems[kwargs["layer_id"]] if mems is not None else None

        # key_layer: [seq_len, batch, num_heads, head_dim]
        seq_len, b, nh, hidden_size = key_layer.shape

        # Cache K and V for incremental decoding, flattened per token to
        # [batch, seq_len, 2 * num_heads * head_dim].
        cache_kv = (
            torch.stack((key_layer, value_layer))
            .permute(2, 1, 0, 3, 4)
            .detach()
            .contiguous()
            .view(b, seq_len, nh * hidden_size * 2)
        )
        kwargs["output_this_layer"]["mem_kv"] = cache_kv

        if mem is not None:
            # Unflatten the cached memory and prepend it along the sequence dim.
            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 1, 0, 3, 4)
            memk, memv = mem[0], mem[1]
            key_layer = torch.cat((memk, key_layer), dim=0)
            value_layer = torch.cat((memv, value_layer), dim=0)

        # Nonzero mask entries mark blocked positions, so a causal mask is True
        # strictly above the diagonal and a full mask is all zeros. Check for
        # None before touching the mask tensor.
        is_full = (attention_mask is None) or (attention_mask == 0).all()
        is_low_triangle = attention_mask is not None and \
            (attention_mask == ~torch.ones_like(attention_mask, dtype=torch.bool).tril()).all()
        if int(torch.__version__.split(".")[0]) >= 2 and (is_full or is_low_triangle):
            # PyTorch >= 2.0 fast path: fused scaled_dot_product_attention.
            dropout_p = 0.0 if attention_dropout is None or not attention_dropout.training else attention_dropout.p

            # [seq, batch, heads, head_dim] -> [batch, heads, seq, head_dim]
            query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
            key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
            value_layer = value_layer.permute(1, 2, 0, 3).contiguous()
            batch_size, num_query_heads = query_layer.shape[:2]
            num_kv_heads = key_layer.shape[1]
            # Repeat KV heads when there are fewer of them than query heads
            # (a no-op here, where Q, K and V come from one shared projection).
            key_layer = (
                key_layer.unsqueeze(2)
                .expand(-1, -1, num_query_heads // num_kv_heads, -1, -1)
                .contiguous()
                .view(batch_size, num_query_heads, *key_layer.shape[2:])
            )
            value_layer = (
                value_layer.unsqueeze(2)
                .expand(-1, -1, num_query_heads // num_kv_heads, -1, -1)
                .contiguous()
                .view(batch_size, num_query_heads, *value_layer.shape[2:])
            )

            # Fork the model-parallel RNG so dropout stays consistent across
            # tensor-parallel ranks.
            if dropout_p > 0 and mpu.get_cuda_rng_tracker() is not None:
                context = mpu.get_cuda_rng_tracker().fork()
            else:
                context = contextlib.nullcontext()

            with context:
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer,
                    attn_mask=None,
                    dropout_p=dropout_p,
                    is_causal=not is_full,
                )

            # [batch, heads, seq, head_dim] -> [seq, batch, heads * head_dim]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (-1,)
            context_layer = context_layer.view(*new_context_layer_shape)
            return context_layer
        else:
            # Manual path: fp16/bf16 matmuls with the softmax computed in fp32.
            # output_size: [batch, heads, seq_q, seq_k]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # Per-layer scaling keeps the fp16 scores in range; it is undone
            # (in fp32) right before the softmax.
            query_key_layer_scaling_coeff = float(kwargs["layer_id"] + 1)
            if scaling_attention_score:
                query_layer = query_layer / (math.sqrt(self.hidden_size_per_attention_head) * query_key_layer_scaling_coeff)

            # [seq, batch, heads, head_dim] -> [seq, batch * heads, head_dim]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            matmul_result = torch.empty(
                output_size[0] * output_size[1],
                output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=torch.cuda.current_device(),
            )

            # Raw attention scores: [batch * heads, seq_q, seq_k]; beta=0.0
            # means the uninitialized buffer is ignored.
            matmul_result = torch.baddbmm(
                matmul_result,
                query_layer.transpose(0, 1),
                key_layer.transpose(0, 1).transpose(1, 2),
                beta=0.0,
                alpha=1.0,
            )

            attention_scores = matmul_result.view(*output_size)

            # Skip masking for auto-regressive decoding steps, where an
            # all-ones mask of query length 1 is passed; otherwise blank out
            # the blocked positions.
            if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
                attention_scores.masked_fill_(attention_mask.bool(), -float("inf"))

            # Softmax in fp32 for numerical stability, undoing the per-layer
            # scaling applied above.
            attention_scores = attention_scores.float()
            attention_scores = attention_scores * query_key_layer_scaling_coeff
            attention_probs = F.softmax(attention_scores, dim=-1)

            if self.fp16:
                attention_probs = attention_probs.half()
            else:
                attention_probs = attention_probs.bfloat16()

            if attention_dropout is not None:
                if mpu.get_cuda_rng_tracker() is not None:
                    with mpu.get_cuda_rng_tracker().fork():
                        attention_probs = attention_dropout(attention_probs)
                else:
                    attention_probs = attention_dropout(attention_probs)

            # output_size: [batch, heads, seq_q, head_dim]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

            # Weighted sum over values: [batch * heads, seq_q, head_dim].
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            context_layer = context_layer.view(*output_size)

            # [batch, heads, seq_q, head_dim] -> [seq_q, batch, hidden_per_partition]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)
            return context_layer
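
# Mask convention note for the mixin above: nonzero entries block attention.
# A causal mask over a length-L sequence is therefore True strictly above the
# diagonal, which is exactly the pattern the fast path detects:
#
#     mask = ~torch.ones(L, L, dtype=torch.bool).tril()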

class FinalForwardMixin(BaseMixin):
    def __init__(self):
        super().__init__()

    def final_forward(self, logits, **kw_args):
        # Project hidden states through the (tied) word-embedding matrix and
        # return [batch, seq_len, vocab] logits.
        return F.linear(logits, self.transformer.word_embeddings.weight).transpose(0, 1).contiguous()

class UntieFinalForwardMixin(BaseMixin):
    def __init__(self, hidden_size, vocab_size, untie_head_num, layernorm_epsilon=1.0e-5):
        super().__init__()
        # vocab_size is unused here: the final projection stays tied to the
        # word-embedding matrix in final_forward.
        self.lm_head = nn.ModuleList()
        for i in range(untie_head_num):
            self.lm_head.append(
                ColumnParallelLinear(
                    hidden_size,
                    2 * hidden_size,
                    gather_output=True,
                    bias=False,
                    module=self,
                    name=f"lm_head.{i}",
                )
            )

        self.head_layernorm = nn.ModuleList()
        for i in range(untie_head_num):
            self.head_layernorm.append(
                LayerNorm(
                    hidden_size,
                    eps=layernorm_epsilon
                )
            )
        self.activation_func = GEGLU()

    def final_forward(self, logits, **kwargs):
        # Head index 1 is used for generation here, which presumes
        # untie_head_num >= 2 (see the --head-num argument below).
        logits = self.lm_head[1](logits)
        logits = self.activation_func(logits)
        logits = self.head_layernorm[1](logits)
        return F.linear(logits, self.transformer.word_embeddings.weight).transpose(0, 1).contiguous()

class NonePositionEmbedding(BaseMixin):
    def __init__(self):
        super().__init__()

    def position_embedding_forward(self, position_ids, output_cross_layer, **kw_args):
        # Positions are injected via rotary embeddings instead.
        return None


class WordEmbedding(BaseMixin):
    def __init__(self):
        super().__init__()

    def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
        # [batch, seq_len] -> [seq_len, batch, hidden_size]
        return self.transformer.word_embeddings(input_ids).transpose(0, 1)

class ProteinGLMForGeneration(BaseModel):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(
            args,
            transformer=transformer,
            **kwargs
        )
        self.add_mixin("glu-deepnorm", DeepNormWithGLUMixin(args.num_layers, args.hidden_size, args.inner_hidden_size))
        self.add_mixin(
            "fp32-softmax",
            SelfAttentionWithFP32SoftmaxMixin(args.fp16, args.hidden_size, args.num_attention_heads, args.model_parallel_size),
        )
        if args.untie_head:
            self.add_mixin("final-forward", UntieFinalForwardMixin(args.hidden_size, args.vocab_size, args.head_num))
        else:
            self.add_mixin("final-forward", FinalForwardMixin())
        self.add_mixin("non-position-embedding", NonePositionEmbedding())
        del self.transformer.position_embeddings
        self.add_mixin("word-embedding", WordEmbedding())
        self.add_mixin(
            "rotary-embedding",
            RotaryEmbeddingMixin(
                args.fp16,
                args.hidden_size,
                args.num_attention_heads,
                args.model_parallel_size,
                args.rotary_embedding_2d,
            ),
        )
        self.get_mixin("glu-deepnorm").reinit()
    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('ProteinGLMForGeneration', 'ProteinGLMForGeneration Configurations')
        group.add_argument('--untie-head', action='store_true', help='use untied output heads')
        group.add_argument('--head-num', default=1, type=int, help='number of untied output heads (more than 1)')
        group.add_argument('--infer-type', default=1, type=int, help='inference type (1 for generation)')
        group.add_argument('--rotary-embedding-2d', action='store_true',
                           help='If set, use 2D rotary embedding for ProteinGLM.')
        return super().add_model_specific_args(parser)
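
# ---------------------------------------------------------------------------
# Usage sketch (for orientation only; not executed). It assumes a sat entry
# point whose argument parser has been extended by add_model_specific_args,
# and the checkpoint name below is a placeholder, not a real identifier.
#
#     from sat import get_args
#
#     args = get_args()                      # includes --untie-head, --head-num, ...
#     model = ProteinGLMForGeneration(args)  # wires up all mixins in __init__
#     # Or load released weights through sat's checkpoint loader, e.g.:
#     # model, args = ProteinGLMForGeneration.from_pretrained("<checkpoint-name>", args)
# ---------------------------------------------------------------------------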