sebasfb99 commited on
Commit
c3585da
·
verified ·
1 Parent(s): b996b2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -7,6 +7,8 @@ import numpy as np
7
  import matplotlib.pyplot as plt
8
  import matplotlib.dates as mdates
9
  from sklearn.metrics import mean_absolute_error, mean_squared_error
 
 
10
  def get_popular_tickers():
11
  return [
12
  "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
@@ -15,6 +17,10 @@ def get_popular_tickers():
15
 
16
  def predict_stock(ticker, train_data_points, prediction_days):
17
  try:
 
 
 
 
18
  # Configurar el pipeline
19
  pipeline = ChronosPipeline.from_pretrained(
20
  "amazon/chronos-t5-mini",
@@ -84,7 +90,6 @@ def predict_stock(ticker, train_data_points, prediction_days):
84
 
85
  # Calcular métricas si hay datos reales para comparar
86
  overlap_end_index = train_data_points + prediction_days
87
- validation_data = None
88
  if overlap_end_index <= total_points:
89
  real_future_dates = df['Date'][train_data_points:overlap_end_index]
90
  real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
@@ -129,7 +134,8 @@ def predict_stock(ticker, train_data_points, prediction_days):
129
 
130
  plt.tight_layout()
131
 
132
- # Crear DataFrame para descarga
 
133
  prediction_df = pd.DataFrame({
134
  'Date': prediction_dates,
135
  'Predicted_Price': median,
@@ -145,10 +151,12 @@ def predict_stock(ticker, train_data_points, prediction_days):
145
  prediction_df = prediction_df[prediction_df['Date'].isin(matching_dates)]
146
  prediction_df['Real_Price'] = real_future_data[:len(prediction_df)]
147
 
148
- csv_path = f"predictions_{ticker}.csv"
149
- prediction_df.to_csv(csv_path, index=False)
 
150
 
151
- return plt, csv_path
 
152
 
153
  except Exception as e:
154
  print(f"Error: {str(e)}")
@@ -166,7 +174,7 @@ with gr.Blocks() as demo:
166
  )
167
  train_data_points = gr.Slider(
168
  minimum=50,
169
- maximum=5000, # Puedes ajustar este valor si lo deseas
170
  value=1000,
171
  step=1,
172
  label="Número de Datos para Entrenamiento"
@@ -179,7 +187,7 @@ with gr.Blocks() as demo:
179
  label="Número de Días a Predecir"
180
  )
181
  predict_btn = gr.Button("Predecir")
182
-
183
  with gr.Column():
184
  plot_output = gr.Plot(label="Gráfico de Predicción")
185
  download_btn = gr.File(label="Descargar Predicciones")
@@ -190,7 +198,7 @@ with gr.Blocks() as demo:
190
  hist = stock.history(period="max")
191
  total_points = len(hist)
192
  # Actualizar el deslizador para reflejar el número total de puntos disponibles
193
- return gr.update(maximum=total_points, value=min(1000, total_points))
194
 
195
  ticker.change(
196
  fn=update_train_data_points,
@@ -204,4 +212,4 @@ with gr.Blocks() as demo:
204
  outputs=[plot_output, download_btn]
205
  )
206
 
207
- demo.launch(debug=True)
 
7
  import matplotlib.pyplot as plt
8
  import matplotlib.dates as mdates
9
  from sklearn.metrics import mean_absolute_error, mean_squared_error
10
+ import tempfile
11
+
12
  def get_popular_tickers():
13
  return [
14
  "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
 
17
 
18
  def predict_stock(ticker, train_data_points, prediction_days):
19
  try:
20
+ # Asegurar que los parámetros sean enteros
21
+ train_data_points = int(train_data_points)
22
+ prediction_days = int(prediction_days)
23
+
24
  # Configurar el pipeline
25
  pipeline = ChronosPipeline.from_pretrained(
26
  "amazon/chronos-t5-mini",
 
90
 
91
  # Calcular métricas si hay datos reales para comparar
92
  overlap_end_index = train_data_points + prediction_days
 
93
  if overlap_end_index <= total_points:
94
  real_future_dates = df['Date'][train_data_points:overlap_end_index]
95
  real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
 
134
 
135
  plt.tight_layout()
136
 
137
+ # Crear un archivo temporal para el CSV
138
+ temp_csv = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
139
  prediction_df = pd.DataFrame({
140
  'Date': prediction_dates,
141
  'Predicted_Price': median,
 
151
  prediction_df = prediction_df[prediction_df['Date'].isin(matching_dates)]
152
  prediction_df['Real_Price'] = real_future_data[:len(prediction_df)]
153
 
154
+ # Guardar el DataFrame en el archivo temporal
155
+ prediction_df.to_csv(temp_csv.name, index=False)
156
+ temp_csv.close()
157
 
158
+ # Retornar el gráfico y la ruta del archivo CSV
159
+ return plt, temp_csv.name
160
 
161
  except Exception as e:
162
  print(f"Error: {str(e)}")
 
174
  )
175
  train_data_points = gr.Slider(
176
  minimum=50,
177
+ maximum=5000,
178
  value=1000,
179
  step=1,
180
  label="Número de Datos para Entrenamiento"
 
187
  label="Número de Días a Predecir"
188
  )
189
  predict_btn = gr.Button("Predecir")
190
+
191
  with gr.Column():
192
  plot_output = gr.Plot(label="Gráfico de Predicción")
193
  download_btn = gr.File(label="Descargar Predicciones")
 
198
  hist = stock.history(period="max")
199
  total_points = len(hist)
200
  # Actualizar el deslizador para reflejar el número total de puntos disponibles
201
+ return gr.Slider.update(maximum=total_points, value=min(1000, total_points))
202
 
203
  ticker.change(
204
  fn=update_train_data_points,
 
212
  outputs=[plot_output, download_btn]
213
  )
214
 
215
+ demo.launch()