Sidharthan committed on
Commit
8001965
·
1 Parent(s): a4e95d0

Resolving the configuration problem

Browse files
Files changed (1) hide show
  1. app.py +90 -50
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
3
  from peft import AutoPeftModelForCausalLM
4
  import torch
5
  import re
 
6
  import os
7
 
 
8
  os.environ['HF_HOME'] = '/app/cache'
9
  hf_token = os.getenv('HF_TOKEN')
10
 
@@ -19,34 +21,65 @@ class StopWordCriteria(StoppingCriteria):
19
 
20
  def load_model():
21
  try:
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
23
  if torch.cuda.is_available():
 
24
  st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
25
  else:
26
- st.warning("Using CPU for inference")
 
27
 
 
28
  model_name = "Sidharthan/gemma2_scripter"
29
 
30
- tokenizer = AutoTokenizer.from_pretrained(
31
- model_name,
32
- trust_remote_code=True,
33
- token=hf_token
34
- )
 
 
 
 
 
 
 
35
 
36
- model = AutoPeftModelForCausalLM.from_pretrained(
37
- model_name,
38
- device_map=None,
39
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
40
- trust_remote_code=True,
41
- low_cpu_mem_usage=True,
42
- cache_dir='/app/cache'
 
 
 
 
43
 
44
- ).to(device)
45
-
46
- return model, tokenizer
 
 
 
 
 
 
 
 
 
47
 
48
  except Exception as e:
49
- st.error(f"Error loading model: {str(e)}")
50
  raise e
51
 
52
  def generate_script(tags, model, tokenizer, params):
@@ -77,6 +110,8 @@ def generate_script(tags, model, tokenizer, params):
77
  stopping_criteria=stopping_criteria
78
  )
79
 
 
 
80
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
81
 
82
  # Clean up response
@@ -111,38 +146,43 @@ def main():
111
  def get_model():
112
  return load_model()
113
 
114
- model, tokenizer = get_model()
115
-
116
- # Tag input section
117
- st.markdown("### Add Tags")
118
- st.markdown("Enter tags separated by commas to generate a YouTube script")
119
-
120
- # Create columns for tag input and generate button
121
- col1, col2 = st.columns([3, 1])
122
-
123
- with col1:
124
- tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
125
-
126
- with col2:
127
- generate_button = st.button("Generate Script", type="primary")
128
-
129
- # Generated script section
130
- if generate_button and tags:
131
- st.markdown("### Generated Script")
132
- with st.spinner("Generating script..."):
133
- script = generate_script(tags, model, tokenizer, params)
134
- st.text_area("Your script:", value=script, height=400)
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- # Add download button
137
- st.download_button(
138
- label="Download Script",
139
- data=script,
140
- file_name="youtube_script.txt",
141
- mime="text/plain"
142
- )
143
-
144
- elif generate_button and not tags:
145
- st.warning("Please enter some tags first!")
146
 
147
  if __name__ == "__main__":
148
  main()
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer
3
  from peft import AutoPeftModelForCausalLM
4
  import torch
5
  import re
6
+ from transformers import StoppingCriteria, StoppingCriteriaList
7
  import os
8
 
9
+ # Set cache directory and get token
10
  os.environ['HF_HOME'] = '/app/cache'
11
  hf_token = os.getenv('HF_TOKEN')
12
 
 
21
 
22
  def load_model():
23
  try:
24
+ # Ensure cache directory exists
25
+ cache_dir = '/app/cache'
26
+ os.makedirs(cache_dir, exist_ok=True)
27
+
28
+ # Check for HF token
29
+ if not hf_token:
30
+ st.warning("HuggingFace token not found. Some models may not be accessible.")
31
+
32
+ # Check CUDA availability
33
  if torch.cuda.is_available():
34
+ device = torch.device("cuda")
35
  st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
36
  else:
37
+ device = torch.device("cpu")
38
+ st.warning("CUDA is not available. Using CPU.")
39
 
40
+ # Fine-tuned model for generating scripts
41
  model_name = "Sidharthan/gemma2_scripter"
42
 
43
+ try:
44
+ tokenizer = AutoTokenizer.from_pretrained(
45
+ model_name,
46
+ trust_remote_code=True,
47
+ token=hf_token,
48
+ cache_dir=cache_dir
49
+ )
50
+ except Exception as e:
51
+ st.error(f"Error loading tokenizer: {str(e)}")
52
+ if "401" in str(e):
53
+ st.error("Authentication error. Please check your HuggingFace token.")
54
+ raise e
55
 
56
+ try:
57
+ # Load model with appropriate device settings
58
+ model = AutoPeftModelForCausalLM.from_pretrained(
59
+ model_name,
60
+ device_map=None, # We'll handle device placement manually
61
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
62
+ trust_remote_code=True,
63
+ low_cpu_mem_usage=True,
64
+ token=hf_token,
65
+ cache_dir=cache_dir
66
+ )
67
 
68
+ # Move model to device
69
+ model = model.to(device)
70
+
71
+ return model, tokenizer
72
+
73
+ except Exception as e:
74
+ st.error(f"Error loading model: {str(e)}")
75
+ if "401" in str(e):
76
+ st.error("Authentication error. Please check your HuggingFace token.")
77
+ elif "disk space" in str(e).lower():
78
+ st.error("Insufficient disk space in cache directory.")
79
+ raise e
80
 
81
  except Exception as e:
82
+ st.error(f"General error during model loading: {str(e)}")
83
  raise e
84
 
85
  def generate_script(tags, model, tokenizer, params):
 
110
  stopping_criteria=stopping_criteria
111
  )
112
 
113
+ # Move outputs back to CPU for decoding
114
+ outputs = outputs.cpu()
115
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
116
 
117
  # Clean up response
 
146
  def get_model():
147
  return load_model()
148
 
149
+ try:
150
+ model, tokenizer = get_model()
151
+
152
+ # Tag input section
153
+ st.markdown("### Add Tags")
154
+ st.markdown("Enter tags separated by commas to generate a YouTube script")
155
+
156
+ # Create columns for tag input and generate button
157
+ col1, col2 = st.columns([3, 1])
158
+
159
+ with col1:
160
+ tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
161
+
162
+ with col2:
163
+ generate_button = st.button("Generate Script", type="primary")
164
+
165
+ # Generated script section
166
+ if generate_button and tags:
167
+ st.markdown("### Generated Script")
168
+ with st.spinner("Generating script..."):
169
+ script = generate_script(tags, model, tokenizer, params)
170
+ st.text_area("Your script:", value=script, height=400)
171
+
172
+ # Add download button
173
+ st.download_button(
174
+ label="Download Script",
175
+ data=script,
176
+ file_name="youtube_script.txt",
177
+ mime="text/plain"
178
+ )
179
+
180
+ elif generate_button and not tags:
181
+ st.warning("Please enter some tags first!")
182
 
183
+ except Exception as e:
184
+ st.error("Failed to initialize the application. Please check the logs for details.")
185
+ st.error(f"Error: {str(e)}")
 
 
 
 
 
 
 
186
 
187
  if __name__ == "__main__":
188
  main()