Sadjad Alikhani commited on
Commit
1484210
·
verified ·
1 Parent(s): 9a7639c

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +170 -170
inference.py CHANGED
@@ -1,171 +1,171 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Sun Sep 15 18:27:17 2024
4
-
5
- @author: salikha4
6
- """
7
-
8
- import os
9
- import csv
10
- import json
11
- import shutil
12
- import random
13
- import argparse
14
- from datetime import datetime
15
- import pandas as pd
16
- import time
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
- from torch.utils.data import Dataset, DataLoader, TensorDataset
21
- from torch.optim import Adam
22
- import numpy as np
23
- from lwm_model import LWM, load_model
24
- import warnings
25
- warnings.filterwarnings('ignore')
26
- from input_preprocess import *
27
-
28
- # Device configuration
29
- device_idx_ds = 3
30
- device = torch.device(f'cuda:{device_idx_ds}' if torch.cuda.is_available() else "cpu")
31
- if torch.cuda.is_available():
32
- torch.cuda.empty_cache()
33
-
34
- # Folders
35
- # MODELS_FOLDER = 'models/'
36
-
37
- def dataset_gen(preprocessed_chs, input_type, scenario_idxs, lwm_model):
38
-
39
- if input_type in ['cls_emb', 'channel_emb']:
40
- dataset = prepare_for_LWM(preprocessed_chs, device)
41
- elif input_type == 'raw':
42
- dataset = create_raw_dataset(preprocessed_chs, device)
43
-
44
- if input_type in ['cls_emb','channel_emb']:
45
- # model = LWM().to(device)
46
- # ckpt_name = 'model_weights.pth'
47
- # ckpt_path = os.path.join(MODELS_FOLDER, ckpt_name)
48
- # lwm_model = load_model(model, ckpt_path, device)
49
- # print(f"Model loaded successfully on {device}")
50
-
51
- # Process data through LWM
52
- lwm_loss, embedding_data = evaluate(lwm_model, dataset)
53
-
54
- print(f'LWM loss: {lwm_loss:.4f}')
55
-
56
- if input_type == 'cls_emb':
57
- embedding_data = embedding_data[:, 0]
58
- elif input_type == 'channel_emb':
59
- embedding_data = embedding_data[:, 1:]
60
-
61
- dataset = embedding_data.float()
62
-
63
- return dataset
64
-
65
-
66
- def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
67
-
68
- input_ids, masked_tokens, masked_pos = zip(*data)
69
-
70
- input_ids_tensor = torch.tensor(input_ids, device=device).float()
71
- masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
72
- masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
73
-
74
- dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
75
-
76
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
77
-
78
-
79
- def create_raw_dataset(data, device):
80
- """Create a dataset for raw channel data."""
81
- input_ids, _, _ = zip(*data)
82
- input_data = torch.tensor(input_ids, device=device)[:, 1:]
83
- return input_data.float()
84
-
85
-
86
- def label_gen(task, data, scenario, n_beams=64):
87
-
88
- idxs = np.where(data['user']['LoS'] != -1)[0]
89
-
90
- if task == 'LoS/NLoS Classification':
91
- label = data['user']['LoS'][idxs]
92
- elif task == 'Beam Prediction':
93
- parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
94
- n_users = len(data['user']['channel'])
95
- n_subbands = 1
96
- fov = 120
97
-
98
- # Setup Beamformers
99
- beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)
100
-
101
- F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
102
- phi=azi*np.pi/180,
103
- kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
104
- for azi in beam_angles])
105
-
106
- full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
107
- for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
108
- if data['user']['LoS'][ue_idx] == -1:
109
- full_dbm[:,:,ue_idx] = np.nan
110
- else:
111
- chs = F1 @ data['user']['channel'][ue_idx]
112
- full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
113
- full_dbm[:,:,ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)
114
-
115
- best_beams = np.argmax(np.mean(full_dbm,axis=1), axis=0)
116
- best_beams = best_beams.astype(float)
117
- best_beams[np.isnan(full_dbm[0,0,:])] = np.nan
118
- max_bf_pwr = np.max(np.mean(full_dbm,axis=1), axis=0)
119
-
120
- label = best_beams[idxs]
121
-
122
- return label.astype(int)
123
-
124
-
125
- def steering_vec(array, phi=0, theta=0, kd=np.pi):
126
- # phi = azimuth
127
- # theta = elevation
128
- idxs = DeepMIMOv3.ant_indices(array)
129
- resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
130
- return resp / np.linalg.norm(resp)
131
-
132
-
133
- def evaluate(model, dataloader):
134
-
135
- model.eval()
136
- running_loss = 0.0
137
- outputs = []
138
- criterionMCM = nn.MSELoss()
139
-
140
- with torch.no_grad():
141
- for batch in dataloader:
142
- input_ids = batch[0]
143
- masked_tokens = batch[1]
144
- masked_pos = batch[2]
145
-
146
- logits_lm, output = model(input_ids, masked_pos)
147
-
148
- output_batch_preproc = output
149
- outputs.append(output_batch_preproc)
150
-
151
- loss_lm = criterionMCM(logits_lm, masked_tokens)
152
- loss = loss_lm/torch.var(masked_tokens)
153
- running_loss += loss.item()
154
-
155
- average_loss = running_loss / len(dataloader)
156
- output_total = torch.cat(outputs, dim=0)
157
-
158
- return average_loss, output_total
159
-
160
-
161
- def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
162
- labels = []
163
- for scenario_idx in scenario_idxs:
164
- scenario_name = scenarios_list()[scenario_idx]
165
- # data = DeepMIMO_data_gen(scenario_name)
166
- data = deepmimo_data[scenario_idx]
167
- labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
168
-
169
- preprocessed_chs = [preprocessed_chs[i] + [labels[i]] for i in range(len(preprocessed_chs))]
170
-
171
  return preprocessed_chs
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Sep 15 18:27:17 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ import os
9
+ import csv
10
+ import json
11
+ import shutil
12
+ import random
13
+ import argparse
14
+ from datetime import datetime
15
+ import pandas as pd
16
+ import time
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
21
+ from torch.optim import Adam
22
+ import numpy as np
23
+ #from lwm_model import LWM, load_model
24
+ import warnings
25
+ warnings.filterwarnings('ignore')
26
+ from input_preprocess import *
27
+
28
+ # Device configuration
29
+ device_idx_ds = 3
30
+ device = torch.device(f'cuda:{device_idx_ds}' if torch.cuda.is_available() else "cpu")
31
+ if torch.cuda.is_available():
32
+ torch.cuda.empty_cache()
33
+
34
+ # Folders
35
+ # MODELS_FOLDER = 'models/'
36
+
37
+ def dataset_gen(preprocessed_chs, input_type, scenario_idxs, lwm_model):
38
+
39
+ if input_type in ['cls_emb', 'channel_emb']:
40
+ dataset = prepare_for_LWM(preprocessed_chs, device)
41
+ elif input_type == 'raw':
42
+ dataset = create_raw_dataset(preprocessed_chs, device)
43
+
44
+ if input_type in ['cls_emb','channel_emb']:
45
+ # model = LWM().to(device)
46
+ # ckpt_name = 'model_weights.pth'
47
+ # ckpt_path = os.path.join(MODELS_FOLDER, ckpt_name)
48
+ # lwm_model = load_model(model, ckpt_path, device)
49
+ # print(f"Model loaded successfully on {device}")
50
+
51
+ # Process data through LWM
52
+ lwm_loss, embedding_data = evaluate(lwm_model, dataset)
53
+
54
+ print(f'LWM loss: {lwm_loss:.4f}')
55
+
56
+ if input_type == 'cls_emb':
57
+ embedding_data = embedding_data[:, 0]
58
+ elif input_type == 'channel_emb':
59
+ embedding_data = embedding_data[:, 1:]
60
+
61
+ dataset = embedding_data.float()
62
+
63
+ return dataset
64
+
65
+
66
+ def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
67
+
68
+ input_ids, masked_tokens, masked_pos = zip(*data)
69
+
70
+ input_ids_tensor = torch.tensor(input_ids, device=device).float()
71
+ masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
72
+ masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
73
+
74
+ dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
75
+
76
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
77
+
78
+
79
+ def create_raw_dataset(data, device):
80
+ """Create a dataset for raw channel data."""
81
+ input_ids, _, _ = zip(*data)
82
+ input_data = torch.tensor(input_ids, device=device)[:, 1:]
83
+ return input_data.float()
84
+
85
+
86
+ def label_gen(task, data, scenario, n_beams=64):
87
+
88
+ idxs = np.where(data['user']['LoS'] != -1)[0]
89
+
90
+ if task == 'LoS/NLoS Classification':
91
+ label = data['user']['LoS'][idxs]
92
+ elif task == 'Beam Prediction':
93
+ parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
94
+ n_users = len(data['user']['channel'])
95
+ n_subbands = 1
96
+ fov = 120
97
+
98
+ # Setup Beamformers
99
+ beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)
100
+
101
+ F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
102
+ phi=azi*np.pi/180,
103
+ kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
104
+ for azi in beam_angles])
105
+
106
+ full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
107
+ for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
108
+ if data['user']['LoS'][ue_idx] == -1:
109
+ full_dbm[:,:,ue_idx] = np.nan
110
+ else:
111
+ chs = F1 @ data['user']['channel'][ue_idx]
112
+ full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
113
+ full_dbm[:,:,ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)
114
+
115
+ best_beams = np.argmax(np.mean(full_dbm,axis=1), axis=0)
116
+ best_beams = best_beams.astype(float)
117
+ best_beams[np.isnan(full_dbm[0,0,:])] = np.nan
118
+ max_bf_pwr = np.max(np.mean(full_dbm,axis=1), axis=0)
119
+
120
+ label = best_beams[idxs]
121
+
122
+ return label.astype(int)
123
+
124
+
125
+ def steering_vec(array, phi=0, theta=0, kd=np.pi):
126
+ # phi = azimuth
127
+ # theta = elevation
128
+ idxs = DeepMIMOv3.ant_indices(array)
129
+ resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
130
+ return resp / np.linalg.norm(resp)
131
+
132
+
133
+ def evaluate(model, dataloader):
134
+
135
+ model.eval()
136
+ running_loss = 0.0
137
+ outputs = []
138
+ criterionMCM = nn.MSELoss()
139
+
140
+ with torch.no_grad():
141
+ for batch in dataloader:
142
+ input_ids = batch[0]
143
+ masked_tokens = batch[1]
144
+ masked_pos = batch[2]
145
+
146
+ logits_lm, output = model(input_ids, masked_pos)
147
+
148
+ output_batch_preproc = output
149
+ outputs.append(output_batch_preproc)
150
+
151
+ loss_lm = criterionMCM(logits_lm, masked_tokens)
152
+ loss = loss_lm/torch.var(masked_tokens)
153
+ running_loss += loss.item()
154
+
155
+ average_loss = running_loss / len(dataloader)
156
+ output_total = torch.cat(outputs, dim=0)
157
+
158
+ return average_loss, output_total
159
+
160
+
161
+ def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
162
+ labels = []
163
+ for scenario_idx in scenario_idxs:
164
+ scenario_name = scenarios_list()[scenario_idx]
165
+ # data = DeepMIMO_data_gen(scenario_name)
166
+ data = deepmimo_data[scenario_idx]
167
+ labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
168
+
169
+ preprocessed_chs = [preprocessed_chs[i] + [labels[i]] for i in range(len(preprocessed_chs))]
170
+
171
  return preprocessed_chs