bwang0911 commited on
Commit
23523c6
·
verified ·
1 Parent(s): d2c8810

fix device and use auto model

Browse files
Files changed (1) hide show
  1. custom_st.py +2 -1
custom_st.py CHANGED
@@ -55,6 +55,7 @@ class Transformer(nn.Module):
55
 
56
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
57
  self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
 
58
 
59
  self._lora_adaptations = config.lora_adaptations
60
  if (
@@ -116,7 +117,7 @@ class Transformer(nn.Module):
116
  lora_arguments = (
117
  {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
118
  )
119
- output_states = self.forward(**features, **lora_arguments, return_dict=False)
120
  output_tokens = output_states[0]
121
  features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
122
  return features
 
55
 
56
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
57
  self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
58
+ self.device = next(self.auto_model.parameters()).device
59
 
60
  self._lora_adaptations = config.lora_adaptations
61
  if (
 
117
  lora_arguments = (
118
  {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
119
  )
120
+ output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
121
  output_tokens = output_states[0]
122
  features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
123
  return features