Vipitis commited on
Commit
8a3ef58
·
1 Parent(s): abed9bd

GPU inference

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. requirements.txt +2 -1
  3. utils/generation.py +6 -1
app.py CHANGED
@@ -51,7 +51,7 @@ outro_text ="""
51
  - [] support FIM task for better model context
52
  - [x] include some context for prompt (title, comments before a functions) - now takes all comments directly before a function as well as all comments at the beginning inside a function. (misses comments between argument list and body)
53
  - [] gradio examples
54
- - [] use GPU if available, respect memory restrictions.
55
  - [x] stream model generation (maybe in a new window?) - janky solution and only sometimes hangs up
56
  - [] 2nd iFrame needs a lot of fixing (I am not a web developer, need help) BUG:background is white, so colors are wrong. Shadertoy uses black background (or we ignore alpha).
57
  - [] (optional) filtering the dataset by license?
 
51
  - [] support FIM task for better model context
52
  - [x] include some context for prompt (title, comments before a functions) - now takes all comments directly before a function as well as all comments at the beginning inside a function. (misses comments between argument list and body)
53
  - [] gradio examples
54
+ - [x] use GPU if available, respect memory restrictions (implemented via accelerate.Accelerator.device in utils.generation.py), tested with A750 successfully!
55
  - [x] stream model generation (maybe in a new window?) - janky solution and only sometimes hangs up
56
  - [] 2nd iFrame needs a lot of fixing (I am not a web developer, need help) BUG:background is white, so colors are wrong. Shadertoy uses black background (or we ignore alpha).
57
  - [] (optional) filtering the dataset by license?
requirements.txt CHANGED
@@ -5,4 +5,5 @@ torch
5
  pillow
6
  gradio
7
  jupylet
8
- tree-sitter
 
 
5
  pillow
6
  gradio
7
  jupylet
8
+ tree-sitter
9
+ accelerate
utils/generation.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import TextIteratorStreamer
2
  from threading import Thread
3
  from .tree_utils import full_func_head, grab_before_comments
@@ -15,17 +16,21 @@ def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, r
15
 
16
 
17
  def stream_generation(prompt:str, pipe, gen_kwargs:dict):
 
 
18
  """
19
  Text generation function
20
  Args:
21
  prompt (str): The context to start generation from.
22
- pipe (Pipeline): The pipeline to use for generation.
23
  gen_kwargs (dict): The generation kwargs.
24
  Returns:
25
  str: The generated text. (it iterates over time)
26
  """
27
  # Tokenize the model_context
28
  model_inputs = pipe.tokenizer(prompt, return_tensors="pt")
 
 
29
 
30
  # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
31
  # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
 
1
+ from accelerate import Accelerator
2
  from transformers import TextIteratorStreamer
3
  from threading import Thread
4
  from .tree_utils import full_func_head, grab_before_comments
 
16
 
17
 
18
  def stream_generation(prompt:str, pipe, gen_kwargs:dict):
19
+ accelerator = Accelerator()
20
+ device = accelerator.device
21
  """
22
  Text generation function
23
  Args:
24
  prompt (str): The context to start generation from.
25
+ pipe (Pipeline): The pipeline to use for generation (we take the model and tokenizer form it)
26
  gen_kwargs (dict): The generation kwargs.
27
  Returns:
28
  str: The generated text. (it iterates over time)
29
  """
30
  # Tokenize the model_context
31
  model_inputs = pipe.tokenizer(prompt, return_tensors="pt")
32
+ model_inputs.to(device)
33
+ model = pipe.model.to(device) #is this also required?
34
 
35
  # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
36
  # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.