Fix precision error
modeling_chatglm.py  CHANGED  (+27 -8)
@@ -5,7 +5,7 @@ import copy
 import warnings
 import re
 import sys
-
+import functools
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
@@ -177,15 +177,21 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, quantized=False, **kwargs):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
+        self.quantized = quantized
 
     def forward(self, hidden_states: torch.Tensor):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        if not self.quantized:
+            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
+            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
+            return self.weight * x_normed
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 
         return (self.weight * hidden_states).to(input_dtype)
 
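The two branches of the patched RMSNorm.forward differ only in where the mean of squares is accumulated: in the activation's own dtype when quantized is False, or in float32 (with a cast back afterwards) when quantized is True. The following is a minimal standalone sketch of that numerical difference; it is not part of modeling_chatglm.py, and the tensor shape, scale, and variable names are illustrative only.

import torch

torch.manual_seed(0)
x = (torch.randn(1, 4096) * 30).to(torch.float16)  # one half-precision activation row
eps = 1e-5

# quantized=False branch: mean of squares accumulated in the input dtype (fp16 here).
norm_x = torch.mean(x * x, dim=-1, keepdim=True)
y_half = x * torch.rsqrt(norm_x + eps)

# quantized=True branch: mean of squares accumulated in float32, result cast back.
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
y_fp32 = (x * torch.rsqrt(variance + eps)).to(x.dtype)

# Usually a small but nonzero gap, caused purely by fp16 rounding during accumulation.
print((y_half.float() - y_fp32.float()).abs().max())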
@@ -515,10 +521,17 @@ class GLMBlock(torch.nn.Module):
 
         self.fp32_residual_connection = config.fp32_residual_connection
 
-        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+        if config.rmsnorm:
+            if config.quantization_bit != 0:
+                LayerNormFunc = functools.partial(RMSNorm, quantized=True)
+            else:
+                LayerNormFunc = RMSNorm
+        else:
+            LayerNormFunc = LayerNorm
+
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                             dtype=config.torch_dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -593,7 +606,13 @@ class GLMTransformer(torch.nn.Module):
         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
 
         if self.post_layer_norm:
-            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            if config.rmsnorm:
+                if config.quantization_bit != 0:
+                    LayerNormFunc = functools.partial(RMSNorm, quantized=True)
+                else:
+                    LayerNormFunc = RMSNorm
+            else:
+                LayerNormFunc = LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)
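Both GLMBlock.__init__ and GLMTransformer.__init__ now select the norm constructor the same way: when config.quantization_bit is nonzero, functools.partial pre-binds quantized=True, so the existing call sites keep passing only the usual hidden size, eps, device, and dtype arguments. A minimal standalone sketch of that pattern follows; NormStub is a hypothetical stand-in for RMSNorm, reduced to the constructor arguments that matter here.

import functools

class NormStub:
    # Stand-in for RMSNorm: only the constructor signature is relevant to the pattern.
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, quantized=False, **kwargs):
        self.quantized = quantized

# With quantization enabled, the partial pre-binds quantized=True ...
LayerNormFunc = functools.partial(NormStub, quantized=True)
# ... and is still called exactly like the plain class at the existing call sites.
norm = LayerNormFunc(4096, eps=1e-5)
print(norm.quantized)  # True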