Sadjad Alikhani commited on
Commit
f3de1b7
·
verified ·
1 Parent(s): cd7cb8b

Update lwm_model.py

Browse files
Files changed (1) hide show
  1. lwm_model.py +1 -47
lwm_model.py CHANGED
@@ -10,11 +10,10 @@ import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  import numpy as np
13
-
14
  from inference import *
15
  from load_data import load_DeepMIMO_data
16
  from input_preprocess import *
17
- #from lwm_model import LWM, load_model
18
 
19
 
20
  ELEMENT_LENGTH = 16
@@ -111,51 +110,6 @@ class EncoderLayer(nn.Module):
111
  attn_outputs = self.norm(attn_outputs)
112
  enc_outputs = self.pos_ffn(attn_outputs)
113
  return enc_outputs, attn
114
-
115
- # class LWM(torch.nn.Module):
116
- # def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
117
- # super().__init__()
118
-
119
- # self.embedding = Embedding(element_length, d_model, max_len)
120
- # self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
121
- # self.linear = nn.Linear(d_model, d_model)
122
- # self.norm = LayerNormalization(d_model)
123
-
124
- # embed_weight = self.embedding.proj.weight
125
- # d_model, n_dim = embed_weight.size()
126
- # self.decoder = nn.Linear(d_model, n_dim, bias=False)
127
- # self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
128
- # self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
129
-
130
- # @classmethod
131
- # def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
132
- # # Define model
133
- # model = cls().to(device)
134
-
135
- # # Download the model weights (from a remote or local repository)
136
- # ckpt_path = f'https://huggingface.co/sadjadalikhani/LWM/resolve/main/{ckpt_name}'
137
-
138
- # # Load the model weights
139
- # model.load_state_dict(torch.hub.load_state_dict_from_url(ckpt_path, map_location=device))
140
- # print(f"Model loaded successfully from {ckpt_path} to {device}")
141
-
142
- # return model
143
-
144
- # def forward(self, input_ids, masked_pos):
145
- # output = self.embedding(input_ids)
146
-
147
- # for layer in self.layers:
148
- # output, _ = layer(output)
149
-
150
- # masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
151
- # h_masked = torch.gather(output, 1, masked_pos)
152
- # h_masked = self.norm(F.relu(self.linear(h_masked)))
153
- # logits_lm = self.decoder(h_masked) + self.decoder_bias
154
-
155
- # return logits_lm, output
156
-
157
- from huggingface_hub import hf_hub_download
158
- import torch
159
 
160
  class LWM(torch.nn.Module):
161
  def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
 
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  import numpy as np
 
13
  from inference import *
14
  from load_data import load_DeepMIMO_data
15
  from input_preprocess import *
16
+ from huggingface_hub import hf_hub_download
17
 
18
 
19
  ELEMENT_LENGTH = 16
 
110
  attn_outputs = self.norm(attn_outputs)
111
  enc_outputs = self.pos_ffn(attn_outputs)
112
  return enc_outputs, attn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  class LWM(torch.nn.Module):
115
  def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):