cvejoski commited on
Commit
e9cdb01
·
verified ·
1 Parent(s): aa8da1d

Upload FIMMJP

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. mjp.py +15 -7
  3. model.safetensors +1 -1
config.json CHANGED
@@ -35,7 +35,7 @@
35
  "out_features": 64
36
  },
37
  "torch_dtype": "float32",
38
- "transformers_version": "4.46.1",
39
  "ts_encoder": {
40
  "embed_dim": 64,
41
  "name": "fim.models.blocks.base.TransformerEncoder",
 
35
  "out_features": 64
36
  },
37
  "torch_dtype": "float32",
38
+ "transformers_version": "4.46.0",
39
  "ts_encoder": {
40
  "embed_dim": 64,
41
  "name": "fim.models.blocks.base.TransformerEncoder",
mjp.py CHANGED
@@ -116,7 +116,7 @@ class FIMMJP(AModel):
116
  initial_distribution_decoder["out_features"] = self.n_states
117
  self.initial_distribution_decoder = create_class_instance(initial_distribution_decoder.pop("name"), initial_distribution_decoder)
118
 
119
- def forward(self, x: dict[str, Tensor], schedulers: dict = None, step: int = None) -> dict:
120
  """
121
  Forward pass for the model.
122
 
@@ -141,20 +141,26 @@ class FIMMJP(AModel):
141
  obs_grid = x["observation_grid"]
142
  if "time_normalization_factors" not in x:
143
  norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
 
 
144
  else:
145
  norm_constants = x["time_normalization_factors"]
 
146
 
147
- obs_values_one_hot = torch.nn.functional.one_hot(x["observation_values"].long().squeeze(-1), num_classes=self.n_states)
148
-
149
- h = self.__encode(x, obs_grid, obs_values_one_hot)
150
 
 
151
  pred_offdiag_im_mean_logvar, init_cond = self.__decode(h)
152
 
153
  pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
154
 
155
  out = {
156
- "im": create_matrix_from_off_diagonal(pred_offdiag_im_mean, self.n_states),
157
- "log_var_im": create_matrix_from_off_diagonal(pred_offdiag_im_logvar, self.n_states),
 
 
 
 
158
  "init_cond": init_cond,
159
  }
160
  if "intensity_matrices" in x and "initial_distributions" in x:
@@ -169,7 +175,9 @@ class FIMMJP(AModel):
169
  init_cond = self.initial_distribution_decoder(h)
170
  return pred_offdiag_logmean_logstd, init_cond
171
 
172
- def __encode(self, x: Tensor, obs_grid_normalized: Tensor, obs_values_one_hot: Tensor) -> Tensor:
 
 
173
  B, P, L = obs_grid_normalized.shape[:3]
174
  pos_enc = self.pos_encodings(obs_grid_normalized)
175
  path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
 
116
  initial_distribution_decoder["out_features"] = self.n_states
117
  self.initial_distribution_decoder = create_class_instance(initial_distribution_decoder.pop("name"), initial_distribution_decoder)
118
 
119
+ def forward(self, x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
120
  """
121
  Forward pass for the model.
122
 
 
141
  obs_grid = x["observation_grid"]
142
  if "time_normalization_factors" not in x:
143
  norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
144
+ x["time_normalization_factors"] = norm_constants
145
+ x["observation_grid_normalized"] = obs_grid
146
  else:
147
  norm_constants = x["time_normalization_factors"]
148
+ x["observation_grid_normalized"] = obs_grid
149
 
150
+ x["observation_values_one_hot"] = torch.nn.functional.one_hot(x["observation_values"].long().squeeze(-1), num_classes=self.n_states)
 
 
151
 
152
+ h = self.__encode(x)
153
  pred_offdiag_im_mean_logvar, init_cond = self.__decode(h)
154
 
155
  pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
156
 
157
  out = {
158
+ "im": create_matrix_from_off_diagonal(
159
+ pred_offdiag_im_mean, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
160
+ ),
161
+ "log_var_im": create_matrix_from_off_diagonal(
162
+ pred_offdiag_im_logvar, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
163
+ ),
164
  "init_cond": init_cond,
165
  }
166
  if "intensity_matrices" in x and "initial_distributions" in x:
 
175
  init_cond = self.initial_distribution_decoder(h)
176
  return pred_offdiag_logmean_logstd, init_cond
177
 
178
+ def __encode(self, x: dict[str, Tensor]) -> Tensor:
179
+ obs_grid_normalized = x["observation_grid_normalized"]
180
+ obs_values_one_hot = x["observation_values_one_hot"]
181
  B, P, L = obs_grid_normalized.shape[:3]
182
  pos_enc = self.pos_encodings(obs_grid_normalized)
183
  path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4f0a742bec9db8d5f4eb153cd53e48055ff740132850b74064ecb6ef76a3b0d4
3
  size 1025616
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97b87e5a321a9c57addcee2d03d9cd7f996f67bdc91f67023eaf495b2d3a885a
3
  size 1025616