LlamaForCausalLM( (model): LlamaModel( (embed_tokens): Embedding(128256, 4096) (layers): ModuleList( (0-5): 6 x LlamaDecoderLayer( (self_attn): LlamaSdpaAttention( (q_proj): BitLinear( in_features=4096, out_features=4096, bias=False (rms_norm): LlamaRMSNorm((4096,), eps=1e-06) ) (k_proj): BitLinear( in_features=4096, out_features=4096, bias=False (rms_norm): LlamaRMSNorm((4096,), eps=1e-06) ) (v_proj): BitLinear( in_features=4096, out_features=4096, bias=False (rms_norm): LlamaRMSNorm((4096,), eps=1e-06) ) (o_proj): BitLinear( in_features=4096, out_features=4096, bias=False (rms_norm): LlamaRMSNorm((4096,), eps=1e-06) ) (rotary_emb): LlamaRotaryEmbedding() ) (mlp): LlamaMLP( (gate_proj): BitLinear( in_features=4096, out_features=2048, bias=False (rms_norm): LlamaRMSNorm((4096,), eps=1e-06) ) (up_proj): BitLinear( in_features=4096, out_features=2048, bias=False (rms_norm): LlamaRMSNorm((4096,), eps=1e-06) ) (down_proj): BitLinear( in_features=2048, out_features=4096, bias=False (rms_norm): LlamaRMSNorm((2048,), eps=1e-06) ) (act_fn): SiLU() ) (input_layernorm): Identity() (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06) ) ) (norm): LlamaRMSNorm((4096,), eps=1e-06) (rotary_emb): LlamaRotaryEmbedding() ) (lm_head): Linear(in_features=4096, out_features=128256, bias=False) )