eneSadi committed on
Commit
225f228
·
unverified ·
1 Parent(s): dda8d50

load model change

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -11,29 +11,33 @@ login(access_token)
11
  model_id = "google/gemma-2-9b-it"
12
  tokenizer = None
13
  model = None
 
14
 
15
@spaces.GPU
def load_model():
    """Populate the module-level ``tokenizer`` and ``model`` globals.

    Downloads/initialises the Gemma checkpoint named by ``model_id`` from the
    Hugging Face hub. Runs on GPU via the ``@spaces.GPU`` decorator.
    """
    global tokenizer, model
    print("Model loading started")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" lets Accelerate choose placement; bfloat16 halves the
    # memory footprint of the 9B checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    print("Model loading completed. Device of the model:", model.device)


# Eagerly load the model once at import time.
load_model()
 
 
 
 
28
 
29
  @spaces.GPU
30
  def ask(prompt):
31
-
32
- global tokenizer, model
33
-
34
  if not prompt:
35
  return {"error": "Prompt is missing"}
36
 
 
 
37
  print("Device of the model:", model.device)
38
  messages = [
39
  {"role": "user", "content": f"{prompt}"},
 
11
  model_id = "google/gemma-2-9b-it"
12
  tokenizer = None
13
  model = None
14
+ model_loaded = False # Flag to check if the model is loaded
15
 
16
@spaces.GPU
def load_model():
    """Return ``(model, tokenizer)``, loading them lazily on first call.

    The module-level ``model_loaded`` flag makes repeated calls cheap: once
    the checkpoint has been initialised, subsequent calls just hand back the
    cached globals.
    """
    global tokenizer, model, model_loaded
    # Guard clause: nothing to do when a previous call already loaded it.
    if model_loaded:
        print("Model is already loaded")
        return model, tokenizer
    print("Model loading started")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" defers placement to Accelerate; bfloat16 keeps the
    # 9B model within GPU memory.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_loaded = True
    print("Model loading completed. Device of the model:", model.device)
    return model, tokenizer
33
 
34
  @spaces.GPU
35
  def ask(prompt):
 
 
 
36
  if not prompt:
37
  return {"error": "Prompt is missing"}
38
 
39
+ if not model_loaded:
40
+ model, tokenizer = load_model() # Ensure the model is loaded before processing
41
  print("Device of the model:", model.device)
42
  messages = [
43
  {"role": "user", "content": f"{prompt}"},