nxphi47 commited on
Commit
71824cc
·
verified ·
1 Parent(s): 5aad54d

Update multipurpose_chatbot/engines/transformers_engine.py

Browse files
multipurpose_chatbot/engines/transformers_engine.py CHANGED
@@ -1,4 +1,5 @@
1
 
 
2
  import os
3
  import numpy as np
4
  import argparse
@@ -420,7 +421,8 @@ class TransformersEngine(BaseEngine):
420
  self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
421
  print(self._model)
422
  print(f"{self.max_position_embeddings=}")
423
-
 
424
  def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
425
 
426
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
@@ -428,7 +430,7 @@ class TransformersEngine(BaseEngine):
428
  inputs = self.tokenizer(prompt, return_tensors='pt')
429
  num_tokens = inputs.input_ids.size(1)
430
 
431
- inputs = inputs.to(self.device_map)
432
 
433
  generator = self._model.generate(
434
  **inputs,
 
1
 
2
+ import spaces
3
  import os
4
  import numpy as np
5
  import argparse
 
421
  self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
422
  print(self._model)
423
  print(f"{self.max_position_embeddings=}")
424
+
425
+ @spaces.GPU
426
  def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
427
 
428
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
 
430
  inputs = self.tokenizer(prompt, return_tensors='pt')
431
  num_tokens = inputs.input_ids.size(1)
432
 
433
+ inputs = inputs.to(self._model.device)
434
 
435
  generator = self._model.generate(
436
  **inputs,