cvejoski committed
Commit 85e192b · verified · 1 Parent(s): afd14e4

Upload FIMMJP

Files changed (3)
  1. config.json +2 -1
  2. mjp.py +108 -38
  3. model.safetensors +1 -1
config.json CHANGED
@@ -55,5 +55,6 @@
       "name": "torch.nn.LSTM"
     }
   },
-  "use_adjacency_matrix": false
+  "use_adjacency_matrix": false,
+  "use_num_of_paths": true
 }
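Not part of the commit, but for orientation: a minimal sketch of how the two flags surface once the configuration is loaded through transformers, assuming the repository exposes FIMMJPConfig via trust_remote_code (the AutoConfig/AutoModel imports in mjp.py suggest it does); the repository id below is a placeholder.

from transformers import AutoConfig

# Placeholder repo id -- substitute the Hub repository that hosts this commit.
config = AutoConfig.from_pretrained("<user>/<fimmjp-repo>", trust_remote_code=True)

# "use_adjacency_matrix" was already present; "use_num_of_paths" is the key added here.
print(config.use_adjacency_matrix)  # False
print(config.use_num_of_paths)      # True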
mjp.py CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Dict
 
 import torch
 from torch import Tensor, nn
+from torch.nn.functional import one_hot
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from fim.models.blocks import AModel, ModelFactory, RNNEncoder, TransformerEncoder
@@ -11,6 +12,30 @@ from fim.utils.helper import create_class_instance
 
 
 class FIMMJPConfig(PretrainedConfig):
+    """
+    FIMMJPConfig is a configuration class for the FIMMJP model.
+    Attributes:
+        model_type (str): The type of the model, default is "fimmjp".
+        n_states (int): Number of states in the model. Default is 2.
+        use_adjacency_matrix (bool): Whether to use an adjacency matrix. Default is False.
+        ts_encoder (dict): Configuration for the time series encoder. Default is None.
+        pos_encodings (dict): Configuration for the positional encodings. Default is None.
+        path_attention (dict): Configuration for the path attention mechanism. Default is None.
+        intensity_matrix_decoder (dict): Configuration for the intensity matrix decoder. Default is None.
+        initial_distribution_decoder (dict): Configuration for the initial distribution decoder. Default is None.
+        use_num_of_paths (bool): Whether to use the number of paths. Default is True.
+    Args:
+        n_states (int, optional): Number of states in the model. Default is 2.
+        use_adjacency_matrix (bool, optional): Whether to use an adjacency matrix. Default is False.
+        ts_encoder (dict, optional): Configuration for the time series encoder. Default is None.
+        pos_encodings (dict, optional): Configuration for the positional encodings. Default is None.
+        path_attention (dict, optional): Configuration for the path attention mechanism. Default is None.
+        intensity_matrix_decoder (dict, optional): Configuration for the intensity matrix decoder. Default is None.
+        initial_distribution_decoder (dict, optional): Configuration for the initial distribution decoder. Default is None.
+        use_num_of_paths (bool, optional): Whether to use the number of paths. Default is True.
+        **kwargs: Additional keyword arguments.
+    """
+
     model_type = "fimmjp"
 
     def __init__(
@@ -22,6 +47,7 @@ class FIMMJPConfig(PretrainedConfig):
         path_attention: dict = None,
         intensity_matrix_decoder: dict = None,
         initial_distribution_decoder: dict = None,
+        use_num_of_paths: bool = True,
         **kwargs,
     ):
         self.n_states = n_states
@@ -31,18 +57,36 @@
         self.path_attention = path_attention
         self.intensity_matrix_decoder = intensity_matrix_decoder
         self.initial_distribution_decoder = initial_distribution_decoder
+        self.use_num_of_paths = use_num_of_paths
 
         super().__init__(**kwargs)
 
 
 class FIMMJP(AModel):
     """
-    FIMMJP: A Neural Recognition Model for Zero-Shot Inference of Markov Jump Processes
-    This class implements a neural recognition model for zero-shot inference of Markov jump processes (MJPs) on bounded state spaces from noisy and sparse observations. The methodology is based on the following paper:
-    Markov jump processes are continuous-time stochastic processes which describe dynamical systems evolving in discrete state spaces. These processes find wide application in the natural sciences and machine learning, but their inference is known to be far from trivial. In this work we introduce a methodology for zero-shot inference of Markov jump processes (MJPs), on bounded state spaces, from noisy and sparse observations, which consists of two components. First, a broad probability distribution over families of MJPs, as well as over possible observation times and noise mechanisms, with which we simulate a synthetic dataset of hidden MJPs and their noisy observations. Second, a neural recognition model that processes subsets of the simulated observations, and that is trained to output the initial condition and rate matrix of the target MJP in a supervised way. We empirically demonstrate that one and the same (pretrained) recognition model can infer, in a zero-shot fashion, hidden MJPs evolving in state spaces of different dimensionalities. Specifically, we infer MJPs which describe (i) discrete flashing ratchet systems, which are a type of Brownian motors, and the conformational dynamics in (ii) molecular simulations, (iii) experimental ion channel data and (iv) simple protein folding models. What is more, we show that our model performs on par with state-of-the-art models which are trained on the target datasets.
+    **FIMMJP: A Neural Recognition Model for Zero-Shot Inference of Markov Jump Processes**
 
-    It is model from the paper:"Foundation Inference Models for Markov Jump Processes" --- https://arxiv.org/abs/2406.06419.
-    Attributes:
+    This class implements a neural recognition model for zero-shot inference of Markov jump processes (MJPs)
+    on bounded state spaces from noisy and sparse observations. The methodology is based on the following paper:
+
+    Markov jump processes are continuous-time stochastic processes which describe dynamical systems evolving in discrete state spaces.
+    These processes find wide application in the natural sciences and machine learning, but their inference is known to be far from trivial.
+    In this work we introduce a methodology for zero-shot inference of Markov jump processes (MJPs),
+    on bounded state spaces, from noisy and sparse observations, which consists of two components.
+
+    First, a broad probability distribution over families of MJPs, as well as over possible observation times and noise mechanisms,
+    with which we simulate a synthetic dataset of hidden MJPs and their noisy observations. Second, a neural recognition model that
+    processes subsets of the simulated observations, and that is trained to output the initial condition and rate matrix of the target
+    MJP in a supervised way.
+
+    We empirically demonstrate that one and the same (pretrained) recognition model can infer, in a zero-shot fashion,
+    hidden MJPs evolving in state spaces of different dimensionalities. Specifically, we infer MJPs which describe
+    *(i) discrete flashing ratchet systems*, which are a type of Brownian motors, and the conformational dynamics in
+    *(ii) molecular simulations*, *(iii) experimental ion channel data* and *(iv) simple protein folding models*.
+    What is more, we show that our model performs on par with state-of-the-art models which are trained on the target datasets.
+
+    It is model from the paper: **"Foundation Inference Models for Markov Jump Processes"** --- https://arxiv.org/abs/2406.06419.
+    **Attributes:**
     n_states (int): Number of states in the Markov jump process.
     use_adjacency_matrix (bool): Whether to use an adjacency matrix.
     ts_encoder (dict | TransformerEncoder): Time series encoder.
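As a quick illustration of the new constructor argument documented above (a hedged sketch; n_states=6 and the omitted sub-module dicts are illustrative, not values taken from this repository):

from mjp import FIMMJPConfig  # assumes mjp.py from this commit is importable

# Sub-module configs (ts_encoder, pos_encodings, ...) are left at their None defaults
# purely to show the new keyword; a working model needs them filled with dicts.
config = FIMMJPConfig(
    n_states=6,                 # illustrative; the class default is 2
    use_adjacency_matrix=False,
    use_num_of_paths=True,      # new in this commit, defaults to True
)
print(config.use_num_of_paths)  # True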
@@ -53,7 +97,7 @@ class FIMMJP(AModel):
     gaussian_nll (nn.GaussianNLLLoss): Gaussian negative log-likelihood loss.
     init_cross_entropy (nn.CrossEntropyLoss): Cross-entropy loss for initial distribution.
 
-    Methods:
+    **Methods:**
     forward(x: dict[str, Tensor], schedulers: dict = None, step: int = None) -> dict:
         Forward pass of the model.
     __decode(h: Tensor) -> tuple[Tensor, Tensor]:
@@ -64,7 +108,8 @@ class FIMMJP(AModel):
         Denormalize the predicted off-diagonal mean and log-variance.
     __normalize_obs_grid(obs_grid: Tensor) -> tuple[Tensor, Tensor]:
         Normalize the observation grid.
-    loss(pred_im: Tensor, pred_logvar_im: Tensor, pred_init_cond: Tensor, target_im: Tensor, target_init_cond: Tensor, adjaceny_matrix: Tensor, normalization_constants: Tensor, schedulers: dict = None, step: int = None) -> dict:
+    loss(pred_im: Tensor, pred_logvar_im: Tensor, pred_init_cond: Tensor, target_im: Tensor, target_init_cond: Tensor,
+        adjaceny_matrix: Tensor, normalization_constants: Tensor, schedulers: dict = None, step: int = None) -> dict:
         Compute the loss for the model.
     new_stats() -> dict:
         Initialize new statistics.
@@ -76,16 +121,12 @@ class FIMMJP(AModel):
 
     def __init__(self, config: FIMMJPConfig, **kwargs):
         super().__init__(config, **kwargs)
-        self.n_states = config.n_states
-        self.use_adjacency_matrix = config.use_adjacency_matrix
-        self.ts_encoder = config.ts_encoder
-        self.total_offdiagonal_transitions = self.n_states**2 - self.n_states
-
-        self.__create_modules()
-
+        self.total_offdiagonal_transitions = self.config.n_states**2 - self.config.n_states
         self.gaussian_nll = nn.GaussianNLLLoss(full=True, reduction="none")
         self.init_cross_entropy = nn.CrossEntropyLoss(reduction="none")
 
+        self.__create_modules()
+
     def __create_modules(self):
         pos_encodings = copy.deepcopy(self.config.pos_encodings)
         ts_encoder = copy.deepcopy(self.config.ts_encoder)
@@ -94,26 +135,28 @@ class FIMMJP(AModel):
         initial_distribution_decoder = copy.deepcopy(self.config.initial_distribution_decoder)
 
         if ts_encoder["name"] == "fim.models.blocks.base.TransformerEncoder":
-            pos_encodings["out_features"] -= self.n_states
+            pos_encodings["out_features"] -= self.config.n_states
         self.pos_encodings = create_class_instance(pos_encodings.pop("name"), pos_encodings)
 
-        ts_encoder["in_features"] = self.n_states + self.pos_encodings.out_features
+        ts_encoder["in_features"] = self.config.n_states + self.pos_encodings.out_features
         self.ts_encoder = create_class_instance(ts_encoder.pop("name"), ts_encoder)
 
         self.path_attention = create_class_instance(path_attention.pop("name"), path_attention)
 
         in_features = intensity_matrix_decoder.get(
-            "in_features", self.ts_encoder.out_features + ((self.total_offdiagonal_transitions + 1) if self.use_adjacency_matrix else 1)
+            "in_features",
+            self.ts_encoder.out_features + ((self.total_offdiagonal_transitions + 1) if self.config.use_adjacency_matrix else 1),
         )
         intensity_matrix_decoder["in_features"] = in_features
         intensity_matrix_decoder["out_features"] = 2 * self.total_offdiagonal_transitions
         self.intensity_matrix_decoder = create_class_instance(intensity_matrix_decoder.pop("name"), intensity_matrix_decoder)
 
         in_features = initial_distribution_decoder.get(
-            "in_features", self.ts_encoder.out_features + ((self.total_offdiagonal_transitions + 1) if self.use_adjacency_matrix else 1)
+            "in_features",
+            self.ts_encoder.out_features + ((self.total_offdiagonal_transitions + 1) if self.config.use_adjacency_matrix else 1),
        )
        initial_distribution_decoder["in_features"] = in_features
-        initial_distribution_decoder["out_features"] = self.n_states
+        initial_distribution_decoder["out_features"] = self.config.n_states
        self.initial_distribution_decoder = create_class_instance(initial_distribution_decoder.pop("name"), initial_distribution_decoder)
 
    def forward(self, x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
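A small arithmetic check on the decoder sizing set up in __create_modules above; the concrete numbers (n_states=6, an encoder width of 256) are assumptions for illustration, not values fixed by this commit:

n_states = 6                                             # assumed; the config default is 2
total_offdiagonal_transitions = n_states**2 - n_states   # 30 off-diagonal rates
use_adjacency_matrix = False                             # matches config.json above
ts_encoder_out_features = 256                            # assumed encoder width

# The trailing "1" accounts for the scalar path-count feature concatenated in __encode;
# with an adjacency matrix, its off-diagonal entries are appended as well.
extra = (total_offdiagonal_transitions + 1) if use_adjacency_matrix else 1

intensity_decoder_in = ts_encoder_out_features + extra      # 257
intensity_decoder_out = 2 * total_offdiagonal_transitions   # 60: a mean and a log-variance per rate
initial_decoder_out = n_states                              # 6
print(intensity_decoder_in, intensity_decoder_out, initial_decoder_out)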
@@ -141,41 +184,68 @@ class FIMMJP(AModel):
         - "losses" (optional): Tensor representing the calculated losses, if the required keys are present in `x`.
         """
 
-        obs_grid = x["observation_grid"]
-        if "time_normalization_factors" not in x:
-            norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
-            x["time_normalization_factors"] = norm_constants
-            x["observation_grid_normalized"] = obs_grid
-        else:
-            norm_constants = x["time_normalization_factors"]
-            x["observation_grid_normalized"] = obs_grid
+        norm_constants = self.__normalize_observation_grid(x)
 
-        x["observation_values_one_hot"] = torch.nn.functional.one_hot(x["observation_values"].long().squeeze(-1), num_classes=self.n_states)
+        x["observation_values_one_hot"] = one_hot(x["observation_values"].long().squeeze(-1), num_classes=self.config.n_states)
 
         h = self.__encode(x)
         pred_offdiag_im_mean_logvar, init_cond = self.__decode(h)
 
         pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
 
+        out = self.__prepare_output(n_states, init_cond, pred_offdiag_im_mean, pred_offdiag_im_logvar)
+        self.__calculate_train_loss_if_targe_exists(
+            x, schedulers, step, norm_constants, init_cond, pred_offdiag_im_mean, pred_offdiag_im_logvar, out
+        )
+
+        return out
+
+    def __calculate_train_loss_if_targe_exists(
+        self,
+        x: dict[str, Tensor],
+        schedulers: dict,
+        step: int,
+        norm_constants: Tensor,
+        init_cond: Tensor,
+        pred_offdiag_im_mean: Tensor,
+        pred_offdiag_im_logvar: Tensor,
+        out: dict,
+    ):
+        if "intensity_matrices" in x and "initial_distributions" in x:
+            out["losses"] = self.loss(
+                pred_offdiag_im_mean, pred_offdiag_im_logvar, init_cond, x, norm_constants.view(-1, 1), schedulers, step
+            )
+
+    def __prepare_output(self, n_states: int, init_cond: Tensor, pred_offdiag_im_mean: Tensor, pred_offdiag_im_logvar: Tensor) -> dict:
         out = {
             "intensity_matrices": create_matrix_from_off_diagonal(
-                pred_offdiag_im_mean, self.n_states, mode="negative_sum_row", n_states=self.n_states if n_states is None else n_states
+                pred_offdiag_im_mean,
+                self.config.n_states,
+                mode="negative_sum_row",
+                n_states=self.config.n_states if n_states is None else n_states,
             ),
             "intensity_matrices_variance": create_matrix_from_off_diagonal(
                 torch.exp(pred_offdiag_im_logvar),
-                self.n_states,
+                self.config.n_states,
                 mode="negative_sum_row",
-                n_states=self.n_states if n_states is None else n_states,
+                n_states=self.config.n_states if n_states is None else n_states,
             ),
             "initial_condition": init_cond,
         }
-        if "intensity_matrices" in x and "initial_distributions" in x:
-            out["losses"] = self.loss(
-                pred_offdiag_im_mean, pred_offdiag_im_logvar, init_cond, x, norm_constants.view(-1, 1), schedulers, step
-            )
 
         return out
 
+    def __normalize_observation_grid(self, x: dict[str, Tensor]) -> Tensor:
+        obs_grid = x["observation_grid"]
+        if "time_normalization_factors" not in x:
+            norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
+            x["time_normalization_factors"] = norm_constants
+            x["observation_grid_normalized"] = obs_grid
+        else:
+            norm_constants = x["time_normalization_factors"]
+            x["observation_grid_normalized"] = obs_grid
+        return norm_constants
+
     def __decode(self, h: Tensor) -> tuple[Tensor, Tensor]:
         pred_offdiag_logmean_logstd = self.intensity_matrix_decoder(h)
         init_cond = self.initial_distribution_decoder(h)
@@ -200,9 +270,9 @@
         last_observation = x["seq_lengths"].view(B * P) - 1
         h = h[torch.arange(B * P), last_observation].view(B, P, -1)
         h = self.path_attention(h, h, h)
-
-        h = torch.cat([h, torch.ones(B, 1).to(h.device) / 100.0 * P], dim=-1)
-        if self.use_adjacency_matrix:
+        if self.config.use_num_of_paths:
+            h = torch.cat([h, torch.ones(B, 1).to(h.device) / 100.0 * P], dim=-1)
+        if self.config.use_adjacency_matrix:
             h = torch.cat([h, get_off_diagonal_elements(x["adjacency_matrix"])], dim=-1)
         return h
 
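Finally, a rough usage sketch of the refactored forward pass. Everything here is inferred from the diff rather than stated by it: the repository id is a placeholder, and the tensor shapes ([B, P, L, 1] grids and values, [B, P, 1] sequence lengths) and dtypes are assumptions.

import torch
from transformers import AutoModel

# Placeholder repo id -- substitute the Hub repository that hosts this commit.
model = AutoModel.from_pretrained("<user>/<fimmjp-repo>", trust_remote_code=True)
model.eval()

B, P, L = 1, 30, 100  # batch size, paths per instance, observations per path (assumed)
n_states = 2          # states present in the data; must not exceed the checkpoint's n_states

x = {
    "observation_grid": torch.rand(B, P, L, 1).sort(dim=-2).values,          # observation times
    "observation_values": torch.randint(0, n_states, (B, P, L, 1)).float(),  # noisy state labels
    "seq_lengths": torch.full((B, P, 1), L, dtype=torch.long),               # observations used per path
}

with torch.no_grad():
    out = model(x, n_states=n_states)

# Keys produced by __prepare_output; "losses" is added only when targets are present in x.
print(out["intensity_matrices"].shape)
print(out["intensity_matrices_variance"].shape)
print(out["initial_condition"].shape)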
 
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:932ce52b1ca140a94b92bfe2ec7dea9fa2625c857e146970aa4689a4f361892f
+oid sha256:e3e2ce91304b2441df5bd69e49cf65e0e31f9322d61e79d536aaca80baca7962
 size 4979384