Cyrano2 commited on
Commit
0c9eb4b
·
verified ·
1 Parent(s): 8ad062b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -1,22 +1,27 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
 
4
 
5
  # Charger le modèle
6
  model_name = "bigcode/starcoder2-15b-instruct-v0.1"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_name,
10
- torch_dtype=torch.float16 # Vous pouvez aussi utiliser torch.float32 pour le CPU
11
- )
12
 
13
  # Fonction pour générer du texte
14
  def generate_text(prompt):
15
- # Utiliser le CPU au lieu du GPU
16
- inputs = tokenizer(prompt, return_tensors="pt")
17
  outputs = model.generate(inputs["input_ids"], max_length=200)
18
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
19
 
 
20
  # Interface utilisateur Gradio
21
  interface = gr.Interface(
22
  fn=generate_text,
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
+ from accelerate import init_empty_weights
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
  # Charger le modèle
8
  model_name = "bigcode/starcoder2-15b-instruct-v0.1"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+
11
+ # Initialisation conditionnelle
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
16
+ ).to(device)
17
 
18
  # Fonction pour générer du texte
19
  def generate_text(prompt):
20
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
21
  outputs = model.generate(inputs["input_ids"], max_length=200)
22
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
23
 
24
+
25
  # Interface utilisateur Gradio
26
  interface = gr.Interface(
27
  fn=generate_text,