Sadjad Alikhani committed on
Commit
b4f2449
·
verified ·
1 Parent(s): c0addc2

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -29
inference.py CHANGED
@@ -23,29 +23,9 @@ import numpy as np
23
  import warnings
24
  warnings.filterwarnings('ignore')
25
 
26
- def set_seed(seed=42):
27
- torch.manual_seed(seed)
28
- np.random.seed(seed)
29
-
30
- # Use this function at the start of your code
31
- set_seed(42)
32
-
33
- # Force model weights and data to float32 precision
34
- def cast_model_weights_to_float32(model):
35
- for param in model.parameters():
36
- param.data = param.data.float() # Cast all weights to float32
37
- return model
38
-
39
- # Device configuration
40
- device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
41
- if torch.cuda.is_available():
42
- torch.cuda.empty_cache()
43
-
44
- def lwm_inference(preprocessed_chs, input_type, lwm_model):
45
 
46
  dataset = prepare_for_LWM(preprocessed_chs, device)
47
-
48
- lwm_model = cast_model_weights_to_float32(lwm_model)
49
  # Process data through LWM
50
  lwm_loss, embedding_data = evaluate(lwm_model, dataset)
51
  print(f'LWM loss: {lwm_loss:.4f}')
@@ -56,15 +36,14 @@ def lwm_inference(preprocessed_chs, input_type, lwm_model):
56
  embedding_data = embedding_data[:, 1:]
57
 
58
  dataset = embedding_data.float()
59
- print(dataset[0][:4])
60
  return dataset
61
 
62
  def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
63
 
64
  input_ids, masked_tokens, masked_pos = zip(*data)
65
 
66
- input_ids_tensor = torch.tensor(input_ids, device=device).float() # Explicitly cast to float32
67
- masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float() # Explicitly cast to float32
68
  masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
69
 
70
  dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
@@ -84,16 +63,13 @@ def evaluate(model, dataloader):
84
  masked_tokens = batch[1]
85
  masked_pos = batch[2]
86
 
87
- if idx == 0:
88
- print(input_ids[0])
89
-
90
  logits_lm, output = model(input_ids, masked_pos)
91
 
92
  output_batch_preproc = output
93
  outputs.append(output_batch_preproc)
94
 
95
  loss_lm = criterionMCM(logits_lm, masked_tokens)
96
- loss = loss_lm / torch.var(masked_tokens) # Use variance for normalization
97
  running_loss += loss.item()
98
 
99
  average_loss = running_loss / len(dataloader)
@@ -104,6 +80,6 @@ def evaluate(model, dataloader):
104
  def create_raw_dataset(data, device):
105
  """Create a dataset for raw channel data."""
106
  input_ids, _, _ = zip(*data)
107
- input_data = torch.tensor(input_ids, device=device).float()[:, 1:] # Explicitly cast to float32
108
  return input_data.float()
109
 
 
23
  import warnings
24
  warnings.filterwarnings('ignore')
25
 
26
+ def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  dataset = prepare_for_LWM(preprocessed_chs, device)
 
 
29
  # Process data through LWM
30
  lwm_loss, embedding_data = evaluate(lwm_model, dataset)
31
  print(f'LWM loss: {lwm_loss:.4f}')
 
36
  embedding_data = embedding_data[:, 1:]
37
 
38
  dataset = embedding_data.float()
 
39
  return dataset
40
 
41
  def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
42
 
43
  input_ids, masked_tokens, masked_pos = zip(*data)
44
 
45
+ input_ids_tensor = torch.tensor(input_ids, device=device).float()
46
+ masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
47
  masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
48
 
49
  dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
 
63
  masked_tokens = batch[1]
64
  masked_pos = batch[2]
65
 
 
 
 
66
  logits_lm, output = model(input_ids, masked_pos)
67
 
68
  output_batch_preproc = output
69
  outputs.append(output_batch_preproc)
70
 
71
  loss_lm = criterionMCM(logits_lm, masked_tokens)
72
+ loss = loss_lm / torch.var(masked_tokens)
73
  running_loss += loss.item()
74
 
75
  average_loss = running_loss / len(dataloader)
 
80
  def create_raw_dataset(data, device):
81
  """Create a dataset for raw channel data."""
82
  input_ids, _, _ = zip(*data)
83
+ input_data = torch.tensor(input_ids, device=device)[:, 1:]
84
  return input_data.float()
85