Upload FIMMJP

Files changed:
- config.json (+2, -1)
- mjp.py (+108, -38)
- model.safetensors (+1, -1)
config.json

@@ -55,5 +55,6 @@
         "name": "torch.nn.LSTM"
       }
     },
-  "use_adjacency_matrix": false
+  "use_adjacency_matrix": false,
+  "use_num_of_paths": true
 }
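The new `use_num_of_paths` flag sits next to `use_adjacency_matrix` at the top level of `config.json`, so it is visible on any loaded configuration. A minimal sketch of inspecting both flags; the repository id is a placeholder, and it is assumed that `mjp.py` registers the custom classes so that `trust_remote_code` loading works:

```python
from transformers import AutoConfig, AutoModel

repo_id = "your-org/fimmjp"  # hypothetical repository id, replace with the actual Hub path

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.use_adjacency_matrix)  # False in this commit's config.json
print(config.use_num_of_paths)      # True, the flag added in this commit

model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
```

Both flags change the width of the features concatenated in `__encode`, so they should match the values the checkpoint was trained with.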
mjp.py

@@ -3,6 +3,7 @@ from typing import Any, Dict
 
 import torch
 from torch import Tensor, nn
+from torch.nn.functional import one_hot
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from fim.models.blocks import AModel, ModelFactory, RNNEncoder, TransformerEncoder
@@ -11,6 +12,30 @@ from fim.utils.helper import create_class_instance
 
 
 class FIMMJPConfig(PretrainedConfig):
+    """
+    FIMMJPConfig is a configuration class for the FIMMJP model.
+    Attributes:
+        model_type (str): The type of the model, default is "fimmjp".
+        n_states (int): Number of states in the model. Default is 2.
+        use_adjacency_matrix (bool): Whether to use an adjacency matrix. Default is False.
+        ts_encoder (dict): Configuration for the time series encoder. Default is None.
+        pos_encodings (dict): Configuration for the positional encodings. Default is None.
+        path_attention (dict): Configuration for the path attention mechanism. Default is None.
+        intensity_matrix_decoder (dict): Configuration for the intensity matrix decoder. Default is None.
+        initial_distribution_decoder (dict): Configuration for the initial distribution decoder. Default is None.
+        use_num_of_paths (bool): Whether to use the number of paths. Default is True.
+    Args:
+        n_states (int, optional): Number of states in the model. Default is 2.
+        use_adjacency_matrix (bool, optional): Whether to use an adjacency matrix. Default is False.
+        ts_encoder (dict, optional): Configuration for the time series encoder. Default is None.
+        pos_encodings (dict, optional): Configuration for the positional encodings. Default is None.
+        path_attention (dict, optional): Configuration for the path attention mechanism. Default is None.
+        intensity_matrix_decoder (dict, optional): Configuration for the intensity matrix decoder. Default is None.
+        initial_distribution_decoder (dict, optional): Configuration for the initial distribution decoder. Default is None.
+        use_num_of_paths (bool, optional): Whether to use the number of paths. Default is True.
+        **kwargs: Additional keyword arguments.
+    """
+
     model_type = "fimmjp"
 
     def __init__(
@@ -22,6 +47,7 @@ class FIMMJPConfig(PretrainedConfig):
         path_attention: dict = None,
         intensity_matrix_decoder: dict = None,
         initial_distribution_decoder: dict = None,
+        use_num_of_paths: bool = True,
         **kwargs,
     ):
         self.n_states = n_states
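Because `use_num_of_paths` is now a constructor argument stored on the config, it round-trips through the standard `PretrainedConfig` serialization. A minimal sketch, assuming `mjp.py` is importable from the working directory and leaving the sub-module dicts at their `None` defaults purely for illustration; a real config needs the encoder and decoder specs shown in `config.json`:

```python
from mjp import FIMMJPConfig  # module changed in this commit, importable when run next to mjp.py

config = FIMMJPConfig(n_states=6, use_adjacency_matrix=False, use_num_of_paths=True)
config.save_pretrained("./fimmjp-demo")  # writes a config.json containing both flags

reloaded = FIMMJPConfig.from_pretrained("./fimmjp-demo")
assert reloaded.model_type == "fimmjp"
assert reloaded.use_num_of_paths is True
```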
@@ -31,18 +57,36 @@
         self.path_attention = path_attention
         self.intensity_matrix_decoder = intensity_matrix_decoder
         self.initial_distribution_decoder = initial_distribution_decoder
+        self.use_num_of_paths = use_num_of_paths
 
         super().__init__(**kwargs)
 
 
 class FIMMJP(AModel):
     """
-    FIMMJP: A Neural Recognition Model for Zero-Shot Inference of Markov Jump Processes
-    This class implements a neural recognition model for zero-shot inference of Markov jump processes (MJPs) on bounded state spaces from noisy and sparse observations. The methodology is based on the following paper:
-    Markov jump processes are continuous-time stochastic processes which describe dynamical systems evolving in discrete state spaces. These processes find wide application in the natural sciences and machine learning, but their inference is known to be far from trivial. In this work we introduce a methodology for zero-shot inference of Markov jump processes (MJPs), on bounded state spaces, from noisy and sparse observations, which consists of two components. First, a broad probability distribution over families of MJPs, as well as over possible observation times and noise mechanisms, with which we simulate a synthetic dataset of hidden MJPs and their noisy observations. Second, a neural recognition model that processes subsets of the simulated observations, and that is trained to output the initial condition and rate matrix of the target MJP in a supervised way. We empirically demonstrate that one and the same (pretrained) recognition model can infer, in a zero-shot fashion, hidden MJPs evolving in state spaces of different dimensionalities. Specifically, we infer MJPs which describe (i) discrete flashing ratchet systems, which are a type of Brownian motors, and the conformational dynamics in (ii) molecular simulations, (iii) experimental ion channel data and (iv) simple protein folding models. What is more, we show that our model performs on par with state-of-the-art models which are trained on the target datasets.
+    **FIMMJP: A Neural Recognition Model for Zero-Shot Inference of Markov Jump Processes**
 
-
-
+    This class implements a neural recognition model for zero-shot inference of Markov jump processes (MJPs)
+    on bounded state spaces from noisy and sparse observations. The methodology is based on the following paper:
+
+    Markov jump processes are continuous-time stochastic processes which describe dynamical systems evolving in discrete state spaces.
+    These processes find wide application in the natural sciences and machine learning, but their inference is known to be far from trivial.
+    In this work we introduce a methodology for zero-shot inference of Markov jump processes (MJPs),
+    on bounded state spaces, from noisy and sparse observations, which consists of two components.
+
+    First, a broad probability distribution over families of MJPs, as well as over possible observation times and noise mechanisms,
+    with which we simulate a synthetic dataset of hidden MJPs and their noisy observations. Second, a neural recognition model that
+    processes subsets of the simulated observations, and that is trained to output the initial condition and rate matrix of the target
+    MJP in a supervised way.
+
+    We empirically demonstrate that one and the same (pretrained) recognition model can infer, in a zero-shot fashion,
+    hidden MJPs evolving in state spaces of different dimensionalities. Specifically, we infer MJPs which describe
+    *(i) discrete flashing ratchet systems*, which are a type of Brownian motors, and the conformational dynamics in
+    *(ii) molecular simulations*, *(iii) experimental ion channel data* and *(iv) simple protein folding models*.
+    What is more, we show that our model performs on par with state-of-the-art models which are trained on the target datasets.
+
+    This is the model from the paper **"Foundation Inference Models for Markov Jump Processes"**, https://arxiv.org/abs/2406.06419.
+    **Attributes:**
     n_states (int): Number of states in the Markov jump process.
     use_adjacency_matrix (bool): Whether to use an adjacency matrix.
     ts_encoder (dict | TransformerEncoder): Time series encoder.
@@ -53,7 +97,7 @@ class FIMMJP(AModel):
     gaussian_nll (nn.GaussianNLLLoss): Gaussian negative log-likelihood loss.
     init_cross_entropy (nn.CrossEntropyLoss): Cross-entropy loss for initial distribution.
 
-    Methods
+    **Methods:**
     forward(x: dict[str, Tensor], schedulers: dict = None, step: int = None) -> dict:
         Forward pass of the model.
     __decode(h: Tensor) -> tuple[Tensor, Tensor]:
@@ -64,7 +108,8 @@
         Denormalize the predicted off-diagonal mean and log-variance.
     __normalize_obs_grid(obs_grid: Tensor) -> tuple[Tensor, Tensor]:
         Normalize the observation grid.
-    loss(pred_im: Tensor, pred_logvar_im: Tensor, pred_init_cond: Tensor, target_im: Tensor, target_init_cond: Tensor,
+    loss(pred_im: Tensor, pred_logvar_im: Tensor, pred_init_cond: Tensor, target_im: Tensor, target_init_cond: Tensor,
+         adjaceny_matrix: Tensor, normalization_constants: Tensor, schedulers: dict = None, step: int = None) -> dict:
         Compute the loss for the model.
     new_stats() -> dict:
         Initialize new statistics.
@@ -76,16 +121,12 @@
 
     def __init__(self, config: FIMMJPConfig, **kwargs):
         super().__init__(config, **kwargs)
-        self.n_states = config.n_states
-        self.use_adjacency_matrix = config.use_adjacency_matrix
-        self.ts_encoder = config.ts_encoder
-        self.total_offdiagonal_transitions = self.n_states**2 - self.n_states
-
-        self.__create_modules()
-
+        self.total_offdiagonal_transitions = self.config.n_states**2 - self.config.n_states
         self.gaussian_nll = nn.GaussianNLLLoss(full=True, reduction="none")
         self.init_cross_entropy = nn.CrossEntropyLoss(reduction="none")
 
+        self.__create_modules()
+
     def __create_modules(self):
         pos_encodings = copy.deepcopy(self.config.pos_encodings)
         ts_encoder = copy.deepcopy(self.config.ts_encoder)
@@ -94,26 +135,28 @@
         initial_distribution_decoder = copy.deepcopy(self.config.initial_distribution_decoder)
 
         if ts_encoder["name"] == "fim.models.blocks.base.TransformerEncoder":
-            pos_encodings["out_features"] -= self.n_states
+            pos_encodings["out_features"] -= self.config.n_states
         self.pos_encodings = create_class_instance(pos_encodings.pop("name"), pos_encodings)
 
-        ts_encoder["in_features"] = self.n_states + self.pos_encodings.out_features
+        ts_encoder["in_features"] = self.config.n_states + self.pos_encodings.out_features
         self.ts_encoder = create_class_instance(ts_encoder.pop("name"), ts_encoder)
 
         self.path_attention = create_class_instance(path_attention.pop("name"), path_attention)
 
         in_features = intensity_matrix_decoder.get(
-            "in_features",
+            "in_features",
+            self.ts_encoder.out_features + ((self.total_offdiagonal_transitions + 1) if self.config.use_adjacency_matrix else 1),
         )
         intensity_matrix_decoder["in_features"] = in_features
         intensity_matrix_decoder["out_features"] = 2 * self.total_offdiagonal_transitions
         self.intensity_matrix_decoder = create_class_instance(intensity_matrix_decoder.pop("name"), intensity_matrix_decoder)
 
         in_features = initial_distribution_decoder.get(
-            "in_features",
+            "in_features",
+            self.ts_encoder.out_features + ((self.total_offdiagonal_transitions + 1) if self.config.use_adjacency_matrix else 1),
        )
         initial_distribution_decoder["in_features"] = in_features
-        initial_distribution_decoder["out_features"] = self.n_states
+        initial_distribution_decoder["out_features"] = self.config.n_states
         self.initial_distribution_decoder = create_class_instance(initial_distribution_decoder.pop("name"), initial_distribution_decoder)
 
     def forward(self, x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
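The default widths in `__create_modules` encode a small piece of arithmetic worth spelling out: each off-diagonal entry of the rate matrix gets a mean and a log-variance from the intensity decoder, and the encoder output is widened by one slot (the path-count feature appended in `__encode`) plus, optionally, the flattened off-diagonal adjacency entries. A worked sketch; the encoder width of 256 is an assumption for illustration, not a value taken from this commit:

```python
n_states = 6
total_offdiagonal_transitions = n_states**2 - n_states   # 30 off-diagonal rates
ts_encoder_out_features = 256                            # assumed encoder width

# Default decoder input width, following the expressions in __create_modules:
in_features_without_adjacency = ts_encoder_out_features + 1                               # 257
in_features_with_adjacency = ts_encoder_out_features + total_offdiagonal_transitions + 1  # 287

# Output widths set unconditionally by __create_modules:
intensity_matrix_out_features = 2 * total_offdiagonal_transitions  # 60: mean and log-variance per rate
initial_distribution_out_features = n_states                       # 6
```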
@@ -141,41 +184,68 @@
             - "losses" (optional): Tensor representing the calculated losses, if the required keys are present in `x`.
         """
 
-        obs_grid = x["observation_grid"]
-        if "time_normalization_factors" not in x:
-            norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
-            x["time_normalization_factors"] = norm_constants
-            x["observation_grid_normalized"] = obs_grid
-        else:
-            norm_constants = x["time_normalization_factors"]
-            x["observation_grid_normalized"] = obs_grid
+        norm_constants = self.__normalize_observation_grid(x)
 
-        x["observation_values_one_hot"] =
+        x["observation_values_one_hot"] = one_hot(x["observation_values"].long().squeeze(-1), num_classes=self.config.n_states)
 
         h = self.__encode(x)
         pred_offdiag_im_mean_logvar, init_cond = self.__decode(h)
 
         pred_offdiag_im_mean, pred_offdiag_im_logvar = self.__denormalize_offdiag_mean_logstd(norm_constants, pred_offdiag_im_mean_logvar)
 
+        out = self.__prepare_output(n_states, init_cond, pred_offdiag_im_mean, pred_offdiag_im_logvar)
+        self.__calculate_train_loss_if_targe_exists(
+            x, schedulers, step, norm_constants, init_cond, pred_offdiag_im_mean, pred_offdiag_im_logvar, out
+        )
+
+        return out
+
+    def __calculate_train_loss_if_targe_exists(
+        self,
+        x: dict[str, Tensor],
+        schedulers: dict,
+        step: int,
+        norm_constants: Tensor,
+        init_cond: Tensor,
+        pred_offdiag_im_mean: Tensor,
+        pred_offdiag_im_logvar: Tensor,
+        out: dict,
+    ):
+        if "intensity_matrices" in x and "initial_distributions" in x:
+            out["losses"] = self.loss(
+                pred_offdiag_im_mean, pred_offdiag_im_logvar, init_cond, x, norm_constants.view(-1, 1), schedulers, step
+            )
+
+    def __prepare_output(self, n_states: int, init_cond: Tensor, pred_offdiag_im_mean: Tensor, pred_offdiag_im_logvar: Tensor) -> dict:
         out = {
             "intensity_matrices": create_matrix_from_off_diagonal(
                 pred_offdiag_im_mean,
+                self.config.n_states,
+                mode="negative_sum_row",
+                n_states=self.config.n_states if n_states is None else n_states,
             ),
             "intensity_matrices_variance": create_matrix_from_off_diagonal(
                 torch.exp(pred_offdiag_im_logvar),
-                self.n_states,
+                self.config.n_states,
                 mode="negative_sum_row",
-                n_states=self.n_states if n_states is None else n_states,
+                n_states=self.config.n_states if n_states is None else n_states,
             ),
             "initial_condition": init_cond,
         }
-        if "intensity_matrices" in x and "initial_distributions" in x:
-            out["losses"] = self.loss(
-                pred_offdiag_im_mean, pred_offdiag_im_logvar, init_cond, x, norm_constants.view(-1, 1), schedulers, step
-            )
 
         return out
 
+    def __normalize_observation_grid(self, x: dict[str, Tensor]) -> Tensor:
+        obs_grid = x["observation_grid"]
+        if "time_normalization_factors" not in x:
+            norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
+            x["time_normalization_factors"] = norm_constants
+            x["observation_grid_normalized"] = obs_grid
+        else:
+            norm_constants = x["time_normalization_factors"]
+            x["observation_grid_normalized"] = obs_grid
+        return norm_constants
+
     def __decode(self, h: Tensor) -> tuple[Tensor, Tensor]:
         pred_offdiag_logmean_logstd = self.intensity_matrix_decoder(h)
         init_cond = self.initial_distribution_decoder(h)
@@ -200,9 +270,9 @@
         last_observation = x["seq_lengths"].view(B * P) - 1
         h = h[torch.arange(B * P), last_observation].view(B, P, -1)
         h = self.path_attention(h, h, h)
-
-
-        if self.use_adjacency_matrix:
+        if self.config.use_num_of_paths:
+            h = torch.cat([h, torch.ones(B, 1).to(h.device) / 100.0 * P], dim=-1)
+        if self.config.use_adjacency_matrix:
             h = torch.cat([h, get_off_diagonal_elements(x["adjacency_matrix"])], dim=-1)
         return h
 
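After this refactor, `forward` still takes a single dictionary and performs the grid normalization, one-hot encoding, and (when targets are present) loss computation internally. A rough usage sketch with a synthetic batch; the `(batch, paths, grid, 1)` layout is inferred from `__encode` and the `one_hot` call above rather than documented, and `model` is assumed to be an already-instantiated `FIMMJP`:

```python
import torch

B, P, T, n_states = 2, 30, 50, 6  # batch size, paths per instance, grid points, states

x = {
    # observation times per path, sorted along the grid dimension
    "observation_grid": torch.rand(B, P, T, 1).sort(dim=-2).values,
    # observed discrete states; one-hot encoded inside forward()
    "observation_values": torch.randint(0, n_states, (B, P, T, 1)),
    # number of valid observations in each path
    "seq_lengths": torch.full((B, P), T, dtype=torch.long),
}

out = model(x)
out["intensity_matrices"]           # predicted rate matrices (negative row sums on the diagonal)
out["intensity_matrices_variance"]  # per-entry predictive variances
out["initial_condition"]            # predicted initial-state distribution
# out["losses"] is only added when "intensity_matrices" and "initial_distributions"
# are supplied in x as training targets.
```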
model.safetensors

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:e3e2ce91304b2441df5bd69e49cf65e0e31f9322d61e79d536aaca80baca7962
 size 4979384
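The only change to `model.safetensors` is the Git LFS pointer's object id; the file size is unchanged. To check that a locally downloaded weights file is the one referenced by this commit, a small sketch (the local path is an assumption):

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "e3e2ce91304b2441df5bd69e49cf65e0e31f9322d61e79d536aaca80baca7962"
EXPECTED_SIZE = 4979384  # bytes, from the LFS pointer above

path = Path("model.safetensors")  # adjust to wherever the file was downloaded

assert path.stat().st_size == EXPECTED_SIZE, "size differs from the LFS pointer"
digest = hashlib.sha256(path.read_bytes()).hexdigest()
assert digest == EXPECTED_SHA256, "sha256 differs from the LFS pointer"
print("model.safetensors matches the pointer in this commit")
```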