ejbejaranos's picture
Training in progress, step 500
eee8c91 verified
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-5): 6 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): BitLinear(
in_features=4096, out_features=4096, bias=False
(rms_norm): LlamaRMSNorm((4096,), eps=1e-06)
)
(k_proj): BitLinear(
in_features=4096, out_features=4096, bias=False
(rms_norm): LlamaRMSNorm((4096,), eps=1e-06)
)
(v_proj): BitLinear(
in_features=4096, out_features=4096, bias=False
(rms_norm): LlamaRMSNorm((4096,), eps=1e-06)
)
(o_proj): BitLinear(
in_features=4096, out_features=4096, bias=False
(rms_norm): LlamaRMSNorm((4096,), eps=1e-06)
)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): BitLinear(
in_features=4096, out_features=2048, bias=False
(rms_norm): LlamaRMSNorm((4096,), eps=1e-06)
)
(up_proj): BitLinear(
in_features=4096, out_features=2048, bias=False
(rms_norm): LlamaRMSNorm((4096,), eps=1e-06)
)
(down_proj): BitLinear(
in_features=2048, out_features=4096, bias=False
(rms_norm): LlamaRMSNorm((2048,), eps=1e-06)
)
(act_fn): SiLU()
)
(input_layernorm): Identity()
(post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
)
)
(norm): LlamaRMSNorm((4096,), eps=1e-06)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)