Upload FIMMJP

Files changed:
- config.json (+1 / -1)
- mjp.py (+15 / -7)
- model.safetensors (+1 / -1)
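For reference, the three touched files can be fetched locally with `huggingface_hub`. This is a minimal sketch only; the repository id is a placeholder, since the commit view does not name the repo.

```python
from huggingface_hub import snapshot_download

# Minimal sketch: fetch only the files touched by this commit.
# The repo_id below is a placeholder, not the actual repository name.
local_dir = snapshot_download(
    repo_id="<org>/<fimmjp-repo>",
    allow_patterns=["config.json", "mjp.py", "model.safetensors"],
)
print(local_dir)
```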
config.json CHANGED
@@ -35,7 +35,7 @@
     "out_features": 64
   },
   "torch_dtype": "float32",
-  "transformers_version": "4.46.
+  "transformers_version": "4.46.0",
   "ts_encoder": {
     "embed_dim": 64,
     "name": "fim.models.blocks.base.TransformerEncoder",
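Blocks like `ts_encoder` above are instantiated by dotted class name, the same `create_class_instance(cfg.pop("name"), cfg)` pattern visible in the mjp.py diff below. A rough stand-in for that pattern (not the actual `fim` implementation) looks like this:

```python
import importlib

def create_class_instance_sketch(dotted_name: str, kwargs: dict):
    # Illustrative stand-in for fim's create_class_instance: resolve the class
    # by its dotted path and instantiate it with the remaining config entries.
    module_path, _, class_name = dotted_name.rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)

# Example using the ts_encoder keys shown in config.json (requires the fim package):
ts_encoder_cfg = {"name": "fim.models.blocks.base.TransformerEncoder", "embed_dim": 64}
# encoder = create_class_instance_sketch(ts_encoder_cfg.pop("name"), ts_encoder_cfg)
```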
mjp.py CHANGED
@@ -116,7 +116,7 @@ class FIMMJP(AModel):
         initial_distribution_decoder["out_features"] = self.n_states
         self.initial_distribution_decoder = create_class_instance(initial_distribution_decoder.pop("name"), initial_distribution_decoder)

-    def forward(self, x: dict[str, Tensor], schedulers: dict = None, step: int = None) -> dict:
+    def forward(self, x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
         """
         Forward pass for the model.

@@ -141,20 +141,26 @@ class FIMMJP(AModel):
         obs_grid = x["observation_grid"]
         if "time_normalization_factors" not in x:
             norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
+            x["time_normalization_factors"] = norm_constants
+            x["observation_grid_normalized"] = obs_grid
         else:
             norm_constants = x["time_normalization_factors"]
+            x["observation_grid_normalized"] = obs_grid

-
-
-        h = self.__encode(x, obs_grid, obs_values_one_hot)
+        x["observation_values_one_hot"] = torch.nn.functional.one_hot(x["observation_values"].long().squeeze(-1), num_classes=self.n_states)

+        h = self.__encode(x)
         pred_offdiag_im_mean_logvar, init_cond = self.__decode(h)

         pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)

         out = {
-            "im": create_matrix_from_off_diagonal(
-
+            "im": create_matrix_from_off_diagonal(
+                pred_offdiag_im_mean, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
+            ),
+            "log_var_im": create_matrix_from_off_diagonal(
+                pred_offdiag_im_logvar, self.n_states, mode="sum_row", n_states=self.n_states if n_states is None else n_states
+            ),
             "init_cond": init_cond,
         }
         if "intensity_matrices" in x and "initial_distributions" in x:
@@ -169,7 +175,9 @@ class FIMMJP(AModel):
         init_cond = self.initial_distribution_decoder(h)
         return pred_offdiag_logmean_logstd, init_cond

-    def __encode(self, x:
+    def __encode(self, x: dict[str, Tensor]) -> Tensor:
+        obs_grid_normalized = x["observation_grid_normalized"]
+        obs_values_one_hot = x["observation_values_one_hot"]
         B, P, L = obs_grid_normalized.shape[:3]
         pos_enc = self.pos_encodings(obs_grid_normalized)
         path = torch.cat([pos_enc, obs_values_one_hot], dim=-1)
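The main behavioral change is that `forward` now computes the normalized grid and the one-hot observation values itself, stores them in `x`, and calls `__encode(x)` with just the dict. A small sketch of that one-hot step, using assumed toy shapes (the diff only fixes the layout B, P, L via `obs_grid_normalized.shape[:3]`; the concrete sizes and state count below are hypothetical):

```python
import torch

# Assumed toy shapes: 2 batches, 30 paths, 50 grid points, 6 MJP states.
B, P, L, n_states = 2, 30, 50, 6
observation_values = torch.randint(0, n_states, (B, P, L, 1)).float()

# The same transformation forward() now applies before calling __encode:
obs_values_one_hot = torch.nn.functional.one_hot(
    observation_values.long().squeeze(-1), num_classes=n_states
)
print(obs_values_one_hot.shape)  # torch.Size([2, 30, 50, 6])
```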
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:97b87e5a321a9c57addcee2d03d9cd7f996f67bdc91f67023eaf495b2d3a885a
 size 1025616
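Because the weights are tracked with Git LFS, the pointer's `oid` is the SHA-256 of the actual file contents, which gives a quick integrity check for a downloaded copy. A short sketch, assuming the file sits at `model.safetensors` in the working directory:

```python
import hashlib

EXPECTED_OID = "97b87e5a321a9c57addcee2d03d9cd7f996f67bdc91f67023eaf495b2d3a885a"

h = hashlib.sha256()
with open("model.safetensors", "rb") as f:  # assumed local path to the downloaded weights
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

assert h.hexdigest() == EXPECTED_OID, "model.safetensors does not match the LFS pointer oid"
```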