Sadjad Alikhani
committed on
Update inference.py
Browse files- inference.py +5 -29
inference.py
CHANGED
@@ -23,29 +23,9 @@ import numpy as np
|
|
23 |
import warnings
|
24 |
warnings.filterwarnings('ignore')
|
25 |
|
26 |
-
def
|
27 |
-
torch.manual_seed(seed)
|
28 |
-
np.random.seed(seed)
|
29 |
-
|
30 |
-
# Use this function at the start of your code
|
31 |
-
set_seed(42)
|
32 |
-
|
33 |
-
# Force model weights and data to float32 precision
|
34 |
-
def cast_model_weights_to_float32(model):
|
35 |
-
for param in model.parameters():
|
36 |
-
param.data = param.data.float() # Cast all weights to float32
|
37 |
-
return model
|
38 |
-
|
39 |
-
# Device configuration
|
40 |
-
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
41 |
-
if torch.cuda.is_available():
|
42 |
-
torch.cuda.empty_cache()
|
43 |
-
|
44 |
-
def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
45 |
|
46 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
47 |
-
|
48 |
-
lwm_model = cast_model_weights_to_float32(lwm_model)
|
49 |
# Process data through LWM
|
50 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
51 |
print(f'LWM loss: {lwm_loss:.4f}')
|
@@ -56,15 +36,14 @@ def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
|
56 |
embedding_data = embedding_data[:, 1:]
|
57 |
|
58 |
dataset = embedding_data.float()
|
59 |
-
print(dataset[0][:4])
|
60 |
return dataset
|
61 |
|
62 |
def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
|
63 |
|
64 |
input_ids, masked_tokens, masked_pos = zip(*data)
|
65 |
|
66 |
-
input_ids_tensor = torch.tensor(input_ids, device=device).float()
|
67 |
-
masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
|
68 |
masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
|
69 |
|
70 |
dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
|
@@ -84,16 +63,13 @@ def evaluate(model, dataloader):
|
|
84 |
masked_tokens = batch[1]
|
85 |
masked_pos = batch[2]
|
86 |
|
87 |
-
if idx == 0:
|
88 |
-
print(input_ids[0])
|
89 |
-
|
90 |
logits_lm, output = model(input_ids, masked_pos)
|
91 |
|
92 |
output_batch_preproc = output
|
93 |
outputs.append(output_batch_preproc)
|
94 |
|
95 |
loss_lm = criterionMCM(logits_lm, masked_tokens)
|
96 |
-
loss = loss_lm / torch.var(masked_tokens)
|
97 |
running_loss += loss.item()
|
98 |
|
99 |
average_loss = running_loss / len(dataloader)
|
@@ -104,6 +80,6 @@ def evaluate(model, dataloader):
|
|
104 |
def create_raw_dataset(data, device):
|
105 |
"""Create a dataset for raw channel data."""
|
106 |
input_ids, _, _ = zip(*data)
|
107 |
-
input_data = torch.tensor(input_ids, device=device)
|
108 |
return input_data.float()
|
109 |
|
|
|
23 |
import warnings
|
24 |
warnings.filterwarnings('ignore')
|
25 |
|
26 |
+
def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
|
|
|
|
29 |
# Process data through LWM
|
30 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
31 |
print(f'LWM loss: {lwm_loss:.4f}')
|
|
|
36 |
embedding_data = embedding_data[:, 1:]
|
37 |
|
38 |
dataset = embedding_data.float()
|
|
|
39 |
return dataset
|
40 |
|
41 |
def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
|
42 |
|
43 |
input_ids, masked_tokens, masked_pos = zip(*data)
|
44 |
|
45 |
+
input_ids_tensor = torch.tensor(input_ids, device=device).float()
|
46 |
+
masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
|
47 |
masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
|
48 |
|
49 |
dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
|
|
|
63 |
masked_tokens = batch[1]
|
64 |
masked_pos = batch[2]
|
65 |
|
|
|
|
|
|
|
66 |
logits_lm, output = model(input_ids, masked_pos)
|
67 |
|
68 |
output_batch_preproc = output
|
69 |
outputs.append(output_batch_preproc)
|
70 |
|
71 |
loss_lm = criterionMCM(logits_lm, masked_tokens)
|
72 |
+
loss = loss_lm / torch.var(masked_tokens)
|
73 |
running_loss += loss.item()
|
74 |
|
75 |
average_loss = running_loss / len(dataloader)
|
|
|
80 |
def create_raw_dataset(data, device):
    """Create a dataset for raw channel data.

    Stacks the input channel sequences from *data* into a single tensor,
    drops the first token of every sequence (presumably a CLS-style
    prefix added during preprocessing — TODO confirm against the
    tokenizer), and casts the result to float32.

    Parameters
    ----------
    data : iterable of (input_ids, masked_tokens, masked_pos) triples
        Only the ``input_ids`` component is used here; each entry must
        have the same length so they stack into one tensor.
    device : torch.device
        Device on which the returned tensor is allocated.

    Returns
    -------
    torch.Tensor
        Float32 tensor of shape ``(len(data), seq_len - 1, ...)``.
    """
    input_ids, _, _ = zip(*data)
    # Build the tensor on the target device, then strip the leading
    # token along the sequence axis (dim 1).
    input_data = torch.tensor(input_ids, device=device)[:, 1:]
    return input_data.float()
|
85 |
|