Crystalcareai commited on
Commit
5e5e800
·
verified ·
1 Parent(s): 1bf3699

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +21 -21
modeling_quiet.py CHANGED
@@ -1254,27 +1254,27 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1254
  self.eval_mode = False
1255
 
1256
  num_talk = 1
1257
- talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
1258
- if self.use_weighted_talk_head:
1259
- talk_output_dim = 1
1260
- else:
1261
- talk_output_dim = config.hidden_size if self.use_shallow_talk else config.vocab_size
1262
-
1263
- if not self.merged_lm_and_talk_heads:
1264
- if self.use_complex_talk_head:
1265
- self.talk_head = nn.ModuleList([nn.Sequential(
1266
- nn.Linear(talk_input_dim, config.hidden_size),
1267
- nn.ReLU(),
1268
- nn.Linear(config.hidden_size, config.hidden_size),
1269
- nn.ReLU(),
1270
- nn.Linear(config.hidden_size, talk_output_dim, bias=False)
1271
- )])
1272
- else:
1273
- self.talk_head = nn.ModuleList([nn.Sequential(
1274
- nn.Linear(talk_input_dim, talk_output_dim, bias=False)
1275
- )])
1276
-
1277
- self.mixing_head = nn.Linear(config.hidden_size * 2, 1)
1278
 
1279
  self.apply(self._init_weights)
1280
 
 
1254
  self.eval_mode = False
1255
 
1256
  num_talk = 1
1257
+ talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
1258
+ if self.use_weighted_talk_head:
1259
+ talk_output_dim = 1
1260
+ else:
1261
+ talk_output_dim = config.hidden_size if self.use_shallow_talk else config.vocab_size
1262
+
1263
+ if not self.merged_lm_and_talk_heads:
1264
+ if self.use_complex_talk_head:
1265
+ self.talk_head = nn.ModuleList([nn.Sequential(
1266
+ nn.Linear(talk_input_dim, config.hidden_size),
1267
+ nn.ReLU(),
1268
+ nn.Linear(config.hidden_size, config.hidden_size),
1269
+ nn.ReLU(),
1270
+ nn.Linear(config.hidden_size, talk_output_dim, bias=False)
1271
+ )])
1272
+ else:
1273
+ self.talk_head = nn.ModuleList([nn.Sequential(
1274
+ nn.Linear(talk_input_dim, talk_output_dim, bias=False)
1275
+ )])
1276
+
1277
+ self.mixing_head = nn.Linear(config.hidden_size * 2, 1)
1278
 
1279
  self.apply(self._init_weights)
1280