"""Gradio app that forecasts stock closing prices with the Chronos model.

The user picks a ticker, how many of the earliest available closing prices to
feed the model as context, and how many steps to forecast. When real prices
exist for the forecast steps they are overlaid and error metrics are shown;
otherwise the forecast is extended over future business days.
"""

import functools
import tempfile

import gradio as gr
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yfinance as yf
from chronos import ChronosPipeline
from matplotlib.figure import Figure
from sklearn.metrics import mean_absolute_error, mean_squared_error

MODEL_NAME = "amazon/chronos-t5-mini"
MIN_HISTORY_POINTS = 50
DEFAULT_TRAIN_POINTS = 1000
DEFAULT_MAX_TRAIN_POINTS = 5000
CONTEXT_DAYS_SHOWN = 10


def get_popular_tickers():
    """Return the ticker symbols offered in the dropdown."""
    return [
        "AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM",
        "JNJ", "V", "PG", "WMT", "BAC", "DIS", "NFLX", "INTC",
    ]


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Load the Chronos pipeline once and reuse it for every request."""
    return ChronosPipeline.from_pretrained(
        MODEL_NAME,
        device_map="cpu",
        torch_dtype=torch.float32,
    )


def _load_close_history(ticker):
    """Fetch the full closing-price history of *ticker*.

    Returns:
        A DataFrame with the index reset and a ``Close`` column (the date
        column is expected to be named ``Date``, as the rest of the file
        assumes).

    Raises:
        ValueError: If no data is returned or fewer than
            ``MIN_HISTORY_POINTS`` rows are available.
    """
    hist = yf.Ticker(ticker).history(period="max")
    if hist.empty:
        raise ValueError(f"No hay datos disponibles para {ticker}")
    if len(hist) < MIN_HISTORY_POINTS:
        raise ValueError(f"Datos insuficientes para {ticker}")
    return hist[["Close"]].reset_index()


def _prediction_dates(dates, train_data_points, prediction_days):
    """Return one date per forecast step.

    Steps that fall inside the known history reuse the real observation
    dates, so forecasts and real prices line up index-for-index even across
    market holidays. Steps beyond the history are filled with business days
    following the last known date.
    """
    known = pd.DatetimeIndex(
        dates.iloc[train_data_points:train_data_points + prediction_days]
    )
    missing = prediction_days - len(known)
    if missing == 0:
        return known

    anchor = known[-1] if len(known) else dates.iloc[-1]
    extension = pd.date_range(
        start=anchor + pd.Timedelta(days=1), periods=missing, freq="B"
    )
    if len(known) == 0:
        return extension
    return known.append(extension)


def _build_figure(ticker, dates, closes, train_data_points, prediction_dates,
                  low, median, high, real_future):
    """Draw history, forecast, interval and (optionally) validation data.

    A standalone ``Figure`` is used instead of ``pyplot`` state so figures
    are not retained globally between requests.
    """
    fig = Figure(figsize=(20, 10))
    ax = fig.subplots()

    # Show a short window of context before the forecast, plus any real
    # observations that exist over the forecast horizon.
    context_days = min(CONTEXT_DAYS_SHOWN, train_data_points)
    start_index = max(0, train_data_points - context_days)
    end_index = min(train_data_points + len(prediction_dates), len(closes))
    ax.plot(dates.iloc[start_index:end_index], closes[start_index:end_index],
            color='blue', linewidth=2, label='Datos Reales')

    ax.plot(prediction_dates, median, color='black', linewidth=2,
            linestyle='-', label='Predicción')
    ax.fill_between(prediction_dates, low, high, color='gray', alpha=0.2,
                    label='Intervalo de Confianza')

    if real_future is not None:
        ax.plot(prediction_dates, real_future, color='red', linewidth=2,
                linestyle='--', label='Datos Reales de Validación')
        mae = mean_absolute_error(real_future, median)
        rmse = np.sqrt(mean_squared_error(real_future, median))
        mape = np.mean(np.abs((real_future - median) / real_future)) * 100
        title = (
            f"Predicción del Precio de {ticker}\n"
            f"MAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.2f}%"
        )
    else:
        title = f"Predicción Futura del Precio de {ticker}"

    ax.set_title(title, fontsize=14, pad=20)
    ax.legend(loc="upper left", fontsize=12)
    ax.set_xlabel("Fecha", fontsize=12)
    ax.set_ylabel("Precio", fontsize=12)
    ax.grid(True, which='both', axis='x', linestyle='--', linewidth=0.5)
    ax.xaxis.set_major_locator(mdates.DayLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    return fig


def _write_predictions_csv(prediction_dates, low, median, high, real_future):
    """Write the forecast table to a temporary CSV and return its path."""
    table = pd.DataFrame({
        'Date': prediction_dates,
        'Predicted_Price': median,
        'Lower_Bound': low,
        'Upper_Bound': high,
    })
    if real_future is not None:
        table['Real_Price'] = real_future

    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix='.csv', delete=False, newline="", encoding="utf-8"
    ) as handle:
        table.to_csv(handle, index=False)
        return handle.name


def predict_stock(ticker, train_data_points, prediction_days):
    """Forecast closing prices for *ticker* and return a plot and a CSV path.

    Args:
        ticker: Symbol to download with yfinance.
        train_data_points: Number of earliest closing prices used as model
            context; clamped to the number of available observations.
        prediction_days: Number of steps to forecast.

    Returns:
        A ``(figure, csv_path)`` tuple: the matplotlib figure and the path of
        a temporary CSV holding the forecast (plus real prices when the whole
        horizon lies inside the known history).

    Raises:
        gr.Error: On any failure, wrapping the original exception.
    """
    try:
        # Slider values may arrive as floats.
        train_data_points = int(train_data_points)
        prediction_days = int(prediction_days)
        if train_data_points < 1 or prediction_days < 1:
            raise ValueError("Los parámetros deben ser enteros positivos")

        df = _load_close_history(ticker)
        total_points = len(df)
        train_data_points = min(train_data_points, total_points)
        closes = df['Close'].to_numpy()

        context = torch.tensor(closes[:train_data_points], dtype=torch.float32)
        forecast = _get_pipeline().predict(
            context, prediction_days, limit_prediction_length=False
        )
        low, median, high = np.quantile(
            forecast[0].numpy(), [0.01, 0.5, 0.99], axis=0
        )

        prediction_dates = _prediction_dates(
            df['Date'], train_data_points, prediction_days
        )

        # Validation data is only used when it covers the whole horizon.
        overlap_end_index = train_data_points + prediction_days
        real_future = None
        if overlap_end_index <= total_points:
            real_future = closes[train_data_points:overlap_end_index]

        figure = _build_figure(
            ticker, df['Date'], closes, train_data_points, prediction_dates,
            low, median, high, real_future,
        )
        csv_path = _write_predictions_csv(
            prediction_dates, low, median, high, real_future
        )
        return figure, csv_path

    except Exception as exc:
        print(f"Error: {exc}")
        raise gr.Error(f"Error al procesar {ticker}: {exc}") from exc


def update_train_data_points(ticker):
    """Return a slider update sized to the history available for *ticker*.

    Falls back to the default range when no ticker is selected or the
    history cannot be loaded.
    """
    if not ticker:
        return gr.update(
            value=DEFAULT_TRAIN_POINTS, maximum=DEFAULT_MAX_TRAIN_POINTS
        )

    try:
        total_points = len(_load_close_history(ticker))
    except Exception as exc:
        # Best effort: keep the UI usable with default bounds.
        print(f"Error al actualizar datos para {ticker}: {exc}")
        return gr.update(
            value=DEFAULT_TRAIN_POINTS,
            maximum=DEFAULT_MAX_TRAIN_POINTS,
            minimum=MIN_HISTORY_POINTS,
            step=1,
        )

    return gr.update(
        maximum=total_points,
        value=min(DEFAULT_TRAIN_POINTS, total_points),
        minimum=MIN_HISTORY_POINTS,
        step=1,
        interactive=True,
    )


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Aplicación de Predicción de Precios de Acciones")

    with gr.Row():
        with gr.Column(scale=1):
            ticker = gr.Dropdown(
                choices=get_popular_tickers(),
                value="AAPL",
                label="Selecciona el Símbolo de la Acción",
                interactive=True,
            )
        with gr.Column():
            train_data_points = gr.Slider(
                minimum=50,
                maximum=5000,
                value=1000,
                step=1,
                label="Número de Datos para Entrenamiento",
                interactive=True,
            )
            prediction_days = gr.Slider(
                minimum=1,
                maximum=60,
                value=5,
                step=1,
                label="Número de Días a Predecir",
                interactive=True,
            )
            predict_btn = gr.Button("Predecir", interactive=True)

    with gr.Column():
        error_output = gr.Textbox(label="Estado", visible=False)
        plot_output = gr.Plot(label="Gráfico de Predicción")
        download_btn = gr.File(label="Descargar Predicciones")

    # Events
    ticker.change(
        fn=update_train_data_points,
        inputs=[ticker],
        outputs=[train_data_points],
        api_name="update_data",
    )
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker, train_data_points, prediction_days],
        outputs=[plot_output, download_btn],
    )

if __name__ == "__main__":
    demo.launch()