sebasfb99 commited on
Commit
4fc19bf
verified
1 Parent(s): d460df7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -33
app.py CHANGED
@@ -15,8 +15,173 @@ def get_popular_tickers():
15
  "JNJ", "V", "PG", "WMT", "BAC", "DIS", "NFLX", "INTC"
16
  ]
17
 
18
- # Resto del c贸digo se mantiene igual hasta la secci贸n de la interfaz Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
 
20
  with gr.Blocks() as demo:
21
  gr.Markdown("# Aplicaci贸n de Predicci贸n de Precios de Acciones")
22
 
@@ -24,47 +189,40 @@ with gr.Blocks() as demo:
24
  with gr.Column(scale=1):
25
  ticker = gr.Dropdown(
26
  choices=get_popular_tickers(),
27
- value="AAPL", # A帽adido valor por defecto
28
- label="Selecciona el S铆mbolo de la Acci贸n"
29
- )
30
- train_data_points = gr.Slider(
31
- minimum=50,
32
- maximum=5000,
33
- value=1000,
34
- step=1,
35
- label="N煤mero de Datos para Entrenamiento"
36
  )
37
- prediction_days = gr.Slider(
38
- minimum=1,
39
- maximum=60,
40
- value=5,
41
- step=1,
42
- label="N煤mero de D铆as a Predecir"
43
- )
44
- predict_btn = gr.Button("Predecir")
 
 
 
 
 
 
 
 
 
 
45
 
46
  with gr.Column():
 
47
  plot_output = gr.Plot(label="Gr谩fico de Predicci贸n")
48
  download_btn = gr.File(label="Descargar Predicciones")
49
 
50
- def update_train_data_points(ticker):
51
- try:
52
- stock = yf.Ticker(ticker)
53
- hist = stock.history(period="max")
54
- total_points = len(hist)
55
- return gr.Slider.update(
56
- maximum=total_points,
57
- value=min(1000, total_points),
58
- visible=True
59
- )
60
- except Exception as e:
61
- print(f"Error updating slider: {str(e)}")
62
- return gr.Slider.update(visible=True) # Mantener slider visible en caso de error
63
-
64
  ticker.change(
65
  fn=update_train_data_points,
66
  inputs=[ticker],
67
- outputs=[train_data_points]
 
68
  )
69
 
70
  predict_btn.click(
 
15
  "JNJ", "V", "PG", "WMT", "BAC", "DIS", "NFLX", "INTC"
16
  ]
17
 
18
+ def predict_stock(ticker, train_data_points, prediction_days):
19
+ try:
20
+ # Asegurar que los par谩metros sean enteros
21
+ train_data_points = int(train_data_points)
22
+ prediction_days = int(prediction_days)
23
+
24
+ # Configurar el pipeline
25
+ pipeline = ChronosPipeline.from_pretrained(
26
+ "amazon/chronos-t5-mini",
27
+ device_map="cpu",
28
+ torch_dtype=torch.float32
29
+ )
30
+
31
+ # Obtener datos hist贸ricos
32
+ stock = yf.Ticker(ticker)
33
+ hist = stock.history(period="max")
34
+ if hist.empty:
35
+ raise ValueError(f"No hay datos disponibles para {ticker}")
36
+
37
+ stock_prices = hist[['Close']].reset_index()
38
+ df = stock_prices.rename(columns={'Date': 'Date', 'Close': f'{ticker}_Close'})
39
+
40
+ total_points = len(df)
41
+ if total_points < 50:
42
+ raise ValueError(f"Datos insuficientes para {ticker}")
43
+
44
+ # Asegurar que el n煤mero de datos de entrenamiento no exceda el total disponible
45
+ train_data_points = min(train_data_points, total_points)
46
+
47
+ # Crear el contexto para entrenamiento
48
+ context = torch.tensor(df[f'{ticker}_Close'][:train_data_points].values, dtype=torch.float32)
49
+
50
+ # Realizar predicci贸n
51
+ forecast = pipeline.predict(context, prediction_days, limit_prediction_length=False)
52
+ low, median, high = np.quantile(forecast[0].numpy(), [0.01, 0.5, 0.99], axis=0)
53
+
54
+ plt.figure(figsize=(20, 10))
55
+ plt.clf()
56
+
57
+ # Determinar el rango de fechas para mostrar
58
+ context_days = min(10, train_data_points)
59
+ start_index = max(0, train_data_points - context_days)
60
+ end_index = min(train_data_points + prediction_days, total_points)
61
+
62
+ # Plotear datos hist贸ricos
63
+ historical_dates = df['Date'][start_index:end_index]
64
+ historical_data = df[f'{ticker}_Close'][start_index:end_index].values
65
+ plt.plot(historical_dates,
66
+ historical_data,
67
+ color='blue',
68
+ linewidth=2,
69
+ label='Datos Reales')
70
+
71
+ # Crear fechas para la predicci贸n
72
+ if train_data_points < total_points:
73
+ prediction_start_date = df['Date'].iloc[train_data_points]
74
+ else:
75
+ last_date = df['Date'].iloc[-1]
76
+ prediction_start_date = last_date + pd.Timedelta(days=1)
77
+
78
+ prediction_dates = pd.date_range(start=prediction_start_date, periods=prediction_days, freq='B')
79
+
80
+ # Plotear predicci贸n
81
+ plt.plot(prediction_dates,
82
+ median,
83
+ color='black',
84
+ linewidth=2,
85
+ linestyle='-',
86
+ label='Predicci贸n')
87
+
88
+ # 脕rea de confianza
89
+ plt.fill_between(prediction_dates, low, high,
90
+ color='gray', alpha=0.2,
91
+ label='Intervalo de Confianza')
92
+
93
+ # Calcular m茅tricas si hay datos reales para comparar
94
+ overlap_end_index = train_data_points + prediction_days
95
+ if overlap_end_index <= total_points:
96
+ real_future_dates = df['Date'][train_data_points:overlap_end_index]
97
+ real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
98
+
99
+ matching_dates = real_future_dates[real_future_dates.isin(prediction_dates)]
100
+ matching_indices = matching_dates.index - train_data_points
101
+ plt.plot(matching_dates,
102
+ real_future_data[matching_indices],
103
+ color='red',
104
+ linewidth=2,
105
+ linestyle='--',
106
+ label='Datos Reales de Validaci贸n')
107
+
108
+ predicted_data = median[:len(matching_indices)]
109
+ mae = mean_absolute_error(real_future_data[matching_indices], predicted_data)
110
+ rmse = np.sqrt(mean_squared_error(real_future_data[matching_indices], predicted_data))
111
+ mape = np.mean(np.abs((real_future_data[matching_indices] - predicted_data) / real_future_data[matching_indices])) * 100
112
+ plt.title(f"Predicci贸n del Precio de {ticker}\nMAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.2f}%",
113
+ fontsize=14, pad=20)
114
+ else:
115
+ plt.title(f"Predicci贸n Futura del Precio de {ticker}",
116
+ fontsize=14, pad=20)
117
+
118
+ plt.legend(loc="upper left", fontsize=12)
119
+ plt.xlabel("Fecha", fontsize=12)
120
+ plt.ylabel("Precio", fontsize=12)
121
+
122
+ plt.grid(True, which='both', axis='x', linestyle='--', linewidth=0.5)
123
+
124
+ ax = plt.gca()
125
+ locator = mdates.DayLocator()
126
+ formatter = mdates.DateFormatter('%Y-%m-%d')
127
+ ax.xaxis.set_major_locator(locator)
128
+ ax.xaxis.set_major_formatter(formatter)
129
+
130
+ plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
131
+
132
+ plt.tight_layout()
133
+
134
+ # Crear archivo CSV temporal
135
+ temp_csv = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
136
+ prediction_df = pd.DataFrame({
137
+ 'Date': prediction_dates,
138
+ 'Predicted_Price': median,
139
+ 'Lower_Bound': low,
140
+ 'Upper_Bound': high
141
+ })
142
+
143
+ if overlap_end_index <= total_points:
144
+ real_future_dates = df['Date'][train_data_points:overlap_end_index]
145
+ real_future_data = df[f'{ticker}_Close'][train_data_points:overlap_end_index].values
146
+ matching_dates = real_future_dates[real_future_dates.isin(prediction_dates)]
147
+ prediction_df = prediction_df[prediction_df['Date'].isin(matching_dates)]
148
+ prediction_df['Real_Price'] = real_future_data[:len(prediction_df)]
149
+
150
+ prediction_df.to_csv(temp_csv.name, index=False)
151
+ temp_csv.close()
152
+
153
+ return plt, temp_csv.name
154
+
155
+ except Exception as e:
156
+ print(f"Error: {str(e)}")
157
+ raise gr.Error(f"Error al procesar {ticker}: {str(e)}")
158
+
159
+ def update_train_data_points(ticker):
160
+ if not ticker:
161
+ return gr.Slider.update(value=1000, maximum=5000)
162
+
163
+ try:
164
+ stock = yf.Ticker(ticker)
165
+ hist = stock.history(period="max")
166
+ if hist.empty:
167
+ raise ValueError(f"No hay datos disponibles para {ticker}")
168
+
169
+ total_points = len(hist)
170
+ if total_points < 50:
171
+ raise ValueError(f"Datos insuficientes para {ticker}")
172
+
173
+ return gr.Slider.update(
174
+ maximum=total_points,
175
+ value=min(1000, total_points),
176
+ minimum=50,
177
+ step=1,
178
+ interactive=True
179
+ )
180
+ except Exception as e:
181
+ print(f"Error al actualizar datos para {ticker}: {str(e)}")
182
+ return gr.Slider.update(value=1000, maximum=5000, minimum=50, step=1)
183
 
184
+ # Interfaz de Gradio
185
  with gr.Blocks() as demo:
186
  gr.Markdown("# Aplicaci贸n de Predicci贸n de Precios de Acciones")
187
 
 
189
  with gr.Column(scale=1):
190
  ticker = gr.Dropdown(
191
  choices=get_popular_tickers(),
192
+ value="AAPL",
193
+ label="Selecciona el S铆mbolo de la Acci贸n",
194
+ interactive=True
 
 
 
 
 
 
195
  )
196
+ with gr.Column():
197
+ train_data_points = gr.Slider(
198
+ minimum=50,
199
+ maximum=5000,
200
+ value=1000,
201
+ step=1,
202
+ label="N煤mero de Datos para Entrenamiento",
203
+ interactive=True
204
+ )
205
+ prediction_days = gr.Slider(
206
+ minimum=1,
207
+ maximum=60,
208
+ value=5,
209
+ step=1,
210
+ label="N煤mero de D铆as a Predecir",
211
+ interactive=True
212
+ )
213
+ predict_btn = gr.Button("Predecir", interactive=True)
214
 
215
  with gr.Column():
216
+ error_output = gr.Textbox(label="Estado", visible=False)
217
  plot_output = gr.Plot(label="Gr谩fico de Predicci贸n")
218
  download_btn = gr.File(label="Descargar Predicciones")
219
 
220
+ # Eventos
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  ticker.change(
222
  fn=update_train_data_points,
223
  inputs=[ticker],
224
+ outputs=[train_data_points],
225
+ api_name="update_data"
226
  )
227
 
228
  predict_btn.click(