Sadjad Alikhani committed
Update lwm_model.py

lwm_model.py +1 -47

lwm_model.py CHANGED
@@ -10,11 +10,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-
 from inference import *
 from load_data import load_DeepMIMO_data
 from input_preprocess import *
-
+from huggingface_hub import hf_hub_download
 
 
 ELEMENT_LENGTH = 16
@@ -111,51 +110,6 @@ class EncoderLayer(nn.Module):
         attn_outputs = self.norm(attn_outputs)
         enc_outputs = self.pos_ffn(attn_outputs)
         return enc_outputs, attn
-
-# class LWM(torch.nn.Module):
-#     def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
-#         super().__init__()
-
-#         self.embedding = Embedding(element_length, d_model, max_len)
-#         self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
-#         self.linear = nn.Linear(d_model, d_model)
-#         self.norm = LayerNormalization(d_model)
-
-#         embed_weight = self.embedding.proj.weight
-#         d_model, n_dim = embed_weight.size()
-#         self.decoder = nn.Linear(d_model, n_dim, bias=False)
-#         self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
-#         self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
-
-#     @classmethod
-#     def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
-#         # Define model
-#         model = cls().to(device)
-
-#         # Download the model weights (from a remote or local repository)
-#         ckpt_path = f'https://huggingface.co/sadjadalikhani/LWM/resolve/main/{ckpt_name}'
-
-#         # Load the model weights
-#         model.load_state_dict(torch.hub.load_state_dict_from_url(ckpt_path, map_location=device))
-#         print(f"Model loaded successfully from {ckpt_path} to {device}")
-
-#         return model
-
-#     def forward(self, input_ids, masked_pos):
-#         output = self.embedding(input_ids)
-
-#         for layer in self.layers:
-#             output, _ = layer(output)
-
-#         masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
-#         h_masked = torch.gather(output, 1, masked_pos)
-#         h_masked = self.norm(F.relu(self.linear(h_masked)))
-#         logits_lm = self.decoder(h_masked) + self.decoder_bias
-
-#         return logits_lm, output
-
-from huggingface_hub import hf_hub_download
-import torch
 
 class LWM(torch.nn.Module):
     def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
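
Note on the change: the dropped comment block fetched weights with torch.hub.load_state_dict_from_url from a hard-coded resolve/main URL, while this commit adds a from huggingface_hub import hf_hub_download at the top of the module instead. The body of the live from_pretrained is outside these hunks, so the following is only a sketch of how loading would typically look with that import; the repo_id is taken from the URL in the removed code, and the rest mirrors the commented-out version.

from huggingface_hub import hf_hub_download
import torch

class LWM(torch.nn.Module):
    def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
        super().__init__()
        # ... embedding, encoder stack, and decoder as defined in lwm_model.py ...

    @classmethod
    def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
        # Build the model on the target device
        model = cls().to(device)
        # hf_hub_download resolves the file through the Hub API, caches it
        # locally, and returns the local path (repo_id assumed from the
        # removed hard-coded URL)
        ckpt_path = hf_hub_download(repo_id='sadjadalikhani/LWM', filename=ckpt_name)
        # Load the downloaded weights
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Model loaded successfully from {ckpt_path} to {device}")
        return model

Unlike the hard-coded URL, hf_hub_download goes through the local Hugging Face cache, so repeated loads reuse the downloaded file, and it also handles revisions and authentication for private repos.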
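
For context on the removed forward: the masked-token reconstruction selects the encoder outputs at the masked positions with torch.gather over the sequence dimension, after expanding the index tensor to the model width. A minimal, self-contained illustration of that indexing pattern (toy shapes, not the model's real dimensions):

import torch

# Toy encoder output: batch of 2 sequences, 5 positions, feature width 4
output = torch.randn(2, 5, 4)
# Two masked positions per sequence
masked_pos = torch.tensor([[1, 3], [0, 4]])

# Expand indices to (batch, n_masked, d_model) so gather pulls the
# whole feature vector at each masked position
idx = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
h_masked = torch.gather(output, 1, idx)
print(h_masked.shape)  # torch.Size([2, 2, 4])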