Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +21 -21
modeling_quiet.py
CHANGED
@@ -1254,27 +1254,27 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1254 |
self.eval_mode = False
|
1255 |
|
1256 |
num_talk = 1
|
1257 |
-
|
1258 |
-
|
1259 |
-
|
1260 |
-
|
1261 |
-
|
1262 |
-
|
1263 |
-
|
1264 |
-
|
1265 |
-
|
1266 |
-
|
1267 |
-
|
1268 |
-
|
1269 |
-
|
1270 |
-
|
1271 |
-
|
1272 |
-
|
1273 |
-
|
1274 |
-
|
1275 |
-
|
1276 |
-
|
1277 |
-
|
1278 |
|
1279 |
self.apply(self._init_weights)
|
1280 |
|
|
|
1254 |
self.eval_mode = False
|
1255 |
|
1256 |
num_talk = 1
|
1257 |
+
talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
|
1258 |
+
if self.use_weighted_talk_head:
|
1259 |
+
talk_output_dim = 1
|
1260 |
+
else:
|
1261 |
+
talk_output_dim = config.hidden_size if self.use_shallow_talk else config.vocab_size
|
1262 |
+
|
1263 |
+
if not self.merged_lm_and_talk_heads:
|
1264 |
+
if self.use_complex_talk_head:
|
1265 |
+
self.talk_head = nn.ModuleList([nn.Sequential(
|
1266 |
+
nn.Linear(talk_input_dim, config.hidden_size),
|
1267 |
+
nn.ReLU(),
|
1268 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
1269 |
+
nn.ReLU(),
|
1270 |
+
nn.Linear(config.hidden_size, talk_output_dim, bias=False)
|
1271 |
+
)])
|
1272 |
+
else:
|
1273 |
+
self.talk_head = nn.ModuleList([nn.Sequential(
|
1274 |
+
nn.Linear(talk_input_dim, talk_output_dim, bias=False)
|
1275 |
+
)])
|
1276 |
+
|
1277 |
+
self.mixing_head = nn.Linear(config.hidden_size * 2, 1)
|
1278 |
|
1279 |
self.apply(self._init_weights)
|
1280 |
|