Spaces:
Build error
Build error
Shubhayu Majumdar
committed on
Commit
·
c133ff5
1
Parent(s):
c7b4f60
unpaid intern fixes
Browse files
app.py
CHANGED
@@ -1,23 +1,33 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
2 |
import yfinance as yf
|
3 |
from datetime import date, timedelta
|
4 |
import streamlit as st
|
5 |
import pandas as pd
|
6 |
import numpy as np
|
7 |
from Models.datamodels import StockNameModel
|
8 |
-
import matplotlib as mpl
|
9 |
import matplotlib.pyplot as plt
|
10 |
-
import matplotlib.style as style
|
11 |
from matplotlib.dates import date2num, DateFormatter, WeekdayLocator,\
|
12 |
DayLocator, MONDAY
|
13 |
import seaborn as sns
|
14 |
import mplfinance as mpf
|
15 |
from mplfinance.original_flavor import candlestick_ohlc
|
|
|
|
|
16 |
|
17 |
|
18 |
|
19 |
class Stonks:
|
|
|
|
|
20 |
def __init__(self, stocks_filepath: str) -> None:
|
|
|
|
|
|
|
21 |
# Classwise global variables
|
22 |
self.stocks = None
|
23 |
self.selected_stock = None
|
@@ -424,7 +434,25 @@ class Stonks:
|
|
424 |
plt.xticks(rotation = 45)
|
425 |
ax.legend()
|
426 |
return fig
|
|
|
|
|
|
|
|
|
427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
def plot_obv_ema(self, data, title_txt: str):
|
429 |
fig, ax = plt.subplots(figsize=(17, 8))
|
430 |
plt.style.use('ggplot')
|
@@ -447,7 +475,7 @@ class Stonks:
|
|
447 |
ax.set_ylabel('Price', fontsize = 15)
|
448 |
ax.legend(loc = 'upper left')
|
449 |
return fig
|
450 |
-
|
451 |
|
452 |
def ui_renderer(self):
|
453 |
st.title('Stonks 📈')
|
@@ -505,10 +533,12 @@ class Stonks:
|
|
505 |
st.error("Error: No data found for selected stock.")
|
506 |
st.stop()
|
507 |
|
508 |
-
# Download Stock data after fetched
|
509 |
-
st.sidebar.markdown("""---""")
|
510 |
-
# st.sidebar.button("Download Data", self.stock_df.to_csv(f"{self.selected_stock}_data.csv", index=False, header=True, encoding='utf-8-sig'))
|
511 |
-
st.sidebar.download_button(label="Download Data", data=self.stock_df.to_csv(index=False, header=True, encoding='utf-8-sig'), file_name=f"{self.selected_stock}_data.csv", mime='text/csv')
|
|
|
|
|
512 |
|
513 |
st.dataframe(self.stock_df)
|
514 |
|
@@ -550,8 +580,7 @@ class Stonks:
|
|
550 |
|
551 |
The moving average crossover trading strategy will be to take two moving averages - 20-day (fast) and 200-day (slow) - and to go long (buy) when the fast MA goes above the slow MA and to go short (sell) when the fast MA goes below the slow MA.
|
552 |
""")
|
553 |
-
|
554 |
-
|
555 |
temp_df = self.stock_df.copy()
|
556 |
temp_df["20d"] = np.round(temp_df["Adj Close"].rolling(window = 20, center = False).mean(), 2)
|
557 |
temp_df["50d"] = np.round(temp_df["Adj Close"].rolling(window = 50, center = False).mean(), 2)
|
@@ -597,14 +626,14 @@ class Stonks:
|
|
597 |
|
598 |
Single Exponential Smoothing, also known as Simple Exponential Smoothing, is a time series forecasting method for univariate data without a trend or seasonality. It requires an alpha parameter, also called the smoothing factor or smoothing coefficient, to control the rate at which the influence of the observations at prior time steps decay exponentially.
|
599 |
""")
|
600 |
-
st.pyplot(self.plot_exponential_smoothing(self.stock_df["Adj Close"], [0.3, 0.05], label_txt=f"{self.selected_stock}", title_txt=f"Single Exponential Smoothing for {self.selected_stock} stock using 0.05 and 0.3 as alpha values"))
|
601 |
|
602 |
st.markdown("""
|
603 |
The smaller the smoothing factor (coefficient), the smoother the time series will be. As the smoothing factor approaches 0, we approach the moving average model so the smoothing factor of 0.05 produces a smoother time series than 0.3. This indicates slow learning (past observations have a large influence on forecasts). A value close to 1 indicates fast learning (that is, only the most recent values influence the forecasts).
|
604 |
|
605 |
**Double Exponential Smoothing (Holt’s Linear Trend Model)** is an extension being a recursive use of Exponential Smoothing twice where beta is the trend smoothing factor, and takes values between 0 and 1. It explicitly adds support for trends.
|
606 |
""")
|
607 |
-
st.pyplot(self.plot_double_exponential_smoothing(self.stock_df["Adj Close"], alphas=[0.9, 0.02], betas=[0.9, 0.02], label_txt=f"{self.selected_stock}", title_txt=f"Double Exponential Smoothing for {self.selected_stock} stock with different alpha and beta values"))
|
608 |
|
609 |
st.markdown("""
|
610 |
The third main type is Triple Exponential Smoothing (Holt Winters Method) which is an extension of Exponential Smoothing that explicitly adds support for seasonality, or periodic fluctuations.
|
@@ -809,6 +838,7 @@ class Stonks:
|
|
809 |
|
810 |
temp_df = get_roc()
|
811 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
|
|
812 |
st.pyplot(mpf.plot(temp_df, type='candle', style='yahoo', figsize=(15,8), title=f"{self.selected_stock} Daily Price", volume=True))
|
813 |
|
814 |
st.markdown("""
|
@@ -923,13 +953,89 @@ class Stonks:
|
|
923 |
temp_df['Buy'], temp_df['Sell'] = buy_sell_obv(temp_df, 'OBV', 'OBV_EMA')
|
924 |
|
925 |
st.pyplot(self.buy_sell_obv_plot(temp_df, title_txt=f"On Balance Volume Buy and Sell Signals for {self.selected_stock} stock"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
926 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
927 |
st.markdown("""---""")
|
928 |
st.markdown("""
|
929 |
## Conclusion
|
930 |
|
931 |
-
|
932 |
-
|
933 |
|
934 |
-
|
935 |
-
stonks.
|
|
|
|
1 |
+
import logging
|
2 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', filename='logger.log')
|
3 |
+
|
4 |
+
|
5 |
+
from typing import List
|
6 |
+
from preds import Predictions
|
7 |
import yfinance as yf
|
8 |
from datetime import date, timedelta
|
9 |
import streamlit as st
|
10 |
import pandas as pd
|
11 |
import numpy as np
|
12 |
from Models.datamodels import StockNameModel
|
|
|
13 |
import matplotlib.pyplot as plt
|
|
|
14 |
from matplotlib.dates import date2num, DateFormatter, WeekdayLocator,\
|
15 |
DayLocator, MONDAY
|
16 |
import seaborn as sns
|
17 |
import mplfinance as mpf
|
18 |
from mplfinance.original_flavor import candlestick_ohlc
|
19 |
+
from lstm import __lstm__
|
20 |
+
|
21 |
|
22 |
|
23 |
|
24 |
class Stonks:
|
25 |
+
|
26 |
+
|
27 |
def __init__(self, stocks_filepath: str) -> None:
|
28 |
+
|
29 |
+
logging.info("Initializing Stonks class")
|
30 |
+
|
31 |
# Classwise global variables
|
32 |
self.stocks = None
|
33 |
self.selected_stock = None
|
|
|
434 |
plt.xticks(rotation = 45)
|
435 |
ax.legend()
|
436 |
return fig
|
437 |
+
|
438 |
+
def plot_lstm_timefm_prediction(self, data, lstm_prediction):
    """Overlay the actual price, the TimesFM forecast and the LSTM forecast.

    Args:
        data: Frame indexed so that ``data['Adj Close']`` and
            ``data['pred_timesfm']`` can be plotted directly.
        lstm_prediction: Sequence of LSTM predictions plotted against its
            positional index.

    Returns:
        The matplotlib figure holding the three curves.
    """
    fig, ax = plt.subplots(figsize=(17, 8))
    sns.set_style('whitegrid')

    # (series, legend label, colour) for each curve, drawn in this order.
    curves = (
        (data['Adj Close'], 'Actual Close', 'tab:blue'),
        (data['pred_timesfm'], 'TimeFM Prediction', 'tab:red'),
        (lstm_prediction, "LSTM Prediction", 'tab:purple'),
    )
    for series, legend_label, colour in curves:
        ax.plot(series, label=legend_label, color=colour, alpha=0.8)

    # Titles and axis labels.
    ax.set_title("LSTM and TimeFM Predictions vs Actual Close", fontsize=16)
    ax.set_xlabel("Date", fontsize=14)
    ax.set_ylabel("Price", fontsize=14)

    # Legend and a light dashed grid.
    ax.legend(fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.5)

    plt.xticks(rotation=45)
    plt.tight_layout()
    return fig
|
455 |
+
|
456 |
def plot_obv_ema(self, data, title_txt: str):
|
457 |
fig, ax = plt.subplots(figsize=(17, 8))
|
458 |
plt.style.use('ggplot')
|
|
|
475 |
ax.set_ylabel('Price', fontsize = 15)
|
476 |
ax.legend(loc = 'upper left')
|
477 |
return fig
|
478 |
+
|
479 |
|
480 |
def ui_renderer(self):
|
481 |
st.title('Stonks 📈')
|
|
|
533 |
st.error("Error: No data found for selected stock.")
|
534 |
st.stop()
|
535 |
|
536 |
+
# # Download Stock data after fetched
|
537 |
+
# st.sidebar.markdown("""---""")
|
538 |
+
# # st.sidebar.button("Download Data", self.stock_df.to_csv(f"{self.selected_stock}_data.csv", index=False, header=True, encoding='utf-8-sig'))
|
539 |
+
# st.sidebar.download_button(label="Download Data", data=self.stock_df.to_csv(index=False, header=True, encoding='utf-8-sig'), file_name=f"{self.selected_stock}_data.csv", mime='text/csv')
|
540 |
+
|
541 |
+
|
542 |
|
543 |
st.dataframe(self.stock_df)
|
544 |
|
|
|
580 |
|
581 |
The moving average crossover trading strategy will be to take two moving averages - 20-day (fast) and 200-day (slow) - and to go long (buy) when the fast MA goes above the slow MA and to go short (sell) when the fast MA goes below the slow MA.
|
582 |
""")
|
583 |
+
|
|
|
584 |
temp_df = self.stock_df.copy()
|
585 |
temp_df["20d"] = np.round(temp_df["Adj Close"].rolling(window = 20, center = False).mean(), 2)
|
586 |
temp_df["50d"] = np.round(temp_df["Adj Close"].rolling(window = 50, center = False).mean(), 2)
|
|
|
626 |
|
627 |
Single Exponential Smoothing, also known as Simple Exponential Smoothing, is a time series forecasting method for univariate data without a trend or seasonality. It requires an alpha parameter, also called the smoothing factor or smoothing coefficient, to control the rate at which the influence of the observations at prior time steps decay exponentially.
|
628 |
""")
|
629 |
+
# st.pyplot(self.plot_exponential_smoothing(self.stock_df["Adj Close"], [0.3, 0.05], label_txt=f"{self.selected_stock}", title_txt=f"Single Exponential Smoothing for {self.selected_stock} stock using 0.05 and 0.3 as alpha values"))
|
630 |
|
631 |
st.markdown("""
|
632 |
The smaller the smoothing factor (coefficient), the smoother the time series will be. As the smoothing factor approaches 0, we approach the moving average model so the smoothing factor of 0.05 produces a smoother time series than 0.3. This indicates slow learning (past observations have a large influence on forecasts). A value close to 1 indicates fast learning (that is, only the most recent values influence the forecasts).
|
633 |
|
634 |
**Double Exponential Smoothing (Holt’s Linear Trend Model)** is an extension being a recursive use of Exponential Smoothing twice where beta is the trend smoothing factor, and takes values between 0 and 1. It explicitly adds support for trends.
|
635 |
""")
|
636 |
+
# st.pyplot(self.plot_double_exponential_smoothing(self.stock_df["Adj Close"], alphas=[0.9, 0.02], betas=[0.9, 0.02], label_txt=f"{self.selected_stock}", title_txt=f"Double Exponential Smoothing for {self.selected_stock} stock with different alpha and beta values"))
|
637 |
|
638 |
st.markdown("""
|
639 |
The third main type is Triple Exponential Smoothing (Holt Winters Method) which is an extension of Exponential Smoothing that explicitly adds support for seasonality, or periodic fluctuations.
|
|
|
838 |
|
839 |
temp_df = get_roc()
|
840 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
841 |
+
temp_df.index = pd.to_datetime(temp_df.index)
|
842 |
st.pyplot(mpf.plot(temp_df, type='candle', style='yahoo', figsize=(15,8), title=f"{self.selected_stock} Daily Price", volume=True))
|
843 |
|
844 |
st.markdown("""
|
|
|
953 |
temp_df['Buy'], temp_df['Sell'] = buy_sell_obv(temp_df, 'OBV', 'OBV_EMA')
|
954 |
|
955 |
st.pyplot(self.buy_sell_obv_plot(temp_df, title_txt=f"On Balance Volume Buy and Sell Signals for {self.selected_stock} stock"))
|
956 |
+
st.markdown("""---""")
|
957 |
+
|
958 |
+
# Predictions start here
|
959 |
+
st.header("Predictions")
|
960 |
+
st.markdown("""
|
961 |
+
We used TimesFM (200M parameters) and LSTM (66K parameters) for stock price prediction, achieving strong alignment with actual data. TimesFM's zero-shot performance on diverse datasets approached state-of-the-art supervised models. Discrepancies noted are expected to reduce with increased parameter size and further training.
|
962 |
+
""")
|
963 |
+
|
964 |
+
st.subheader("TimesFM (Time Series Foundation Model)")
|
965 |
+
st.markdown("""
|
966 |
+
TimesFM (200M parameters) uses long output patches to reduce error accumulation, enabling accurate long-horizon forecasts. Trained on sequences with varying prediction horizons, it excels in zero-shot predictions across diverse datasets, effectively predicting stock price movements.
|
967 |
+
|
968 |
+
""")
|
969 |
+
# ------------------------------------- Times FM ---------------------------------
|
970 |
+
temp_df = self.stock_df.copy()
|
971 |
+
temp_df.reset_index(inplace=True)
|
972 |
+
temp_df.rename(columns={'index': 'Date'}, inplace=True)
|
973 |
+
|
974 |
+
|
975 |
+
if 'pred' not in st.session_state:
|
976 |
+
st.session_state.pred = Predictions()
|
977 |
+
|
978 |
+
|
979 |
+
if "pred" in st.session_state:
|
980 |
+
|
981 |
+
st.session_state.pred.data_preprocess(
|
982 |
+
data = temp_df,
|
983 |
+
target_colm="Adj Close",
|
984 |
+
date_colm="Date",
|
985 |
+
)
|
986 |
+
|
987 |
+
|
988 |
+
stock_preds = st.session_state.pred.predict()
|
989 |
+
temp_df["pred_timesfm"] = stock_preds
|
990 |
+
|
991 |
|
992 |
+
fig, ax = plt.subplots(figsize=(20, 10))
|
993 |
+
temp_df[["Adj Close", "pred_timesfm"]].plot(ax=ax)
|
994 |
+
ax.set_title(f"{self.selected_stock} Price vs TimesFM Predictions", fontsize=18)
|
995 |
+
ax.set_xlabel("Date", fontsize=14)
|
996 |
+
ax.set_ylabel("Price", fontsize=14)
|
997 |
+
ax.legend(["Actual Price", "Predicted Price"], loc="upper left", fontsize=12)
|
998 |
+
ax.grid(True, linestyle='--', alpha=0.7)
|
999 |
+
st.pyplot(fig)
|
1000 |
+
|
1001 |
+
|
1002 |
+
# -------------------------------- LSTM ------------------------------------------
|
1003 |
+
st.subheader("LSTM (Long Short Term Memory)")
|
1004 |
+
st.markdown("""
|
1005 |
+
LSTM, equipped with 66,000 parameters, effectively captures long-term dependencies in stock market data. Its recurrent architecture enables accurate prediction of stock price movements, making it a valuable tool for financial forecasting.
|
1006 |
+
""")
|
1007 |
+
|
1008 |
+
|
1009 |
+
lstm_predictions = __lstm__(temp_df)
|
1010 |
+
|
1011 |
+
fig, ax = plt.subplots(figsize=(20, 10))
|
1012 |
+
temp_df[["Adj Close"]].plot(ax=ax)
|
1013 |
+
ax.plot(lstm_predictions, label = "LSTM", alpha = 0.5)
|
1014 |
+
ax.set_title(f"{self.selected_stock} Price vs LSTM Predictions", fontsize=18)
|
1015 |
+
ax.set_xlabel("Date", fontsize=14)
|
1016 |
+
ax.set_ylabel("Price", fontsize=14)
|
1017 |
+
ax.legend(["Actual Price", "Predicted Price"], loc="upper left", fontsize=12)
|
1018 |
+
ax.grid(True, linestyle='--', alpha=0.7)
|
1019 |
+
st.pyplot(fig)
|
1020 |
+
|
1021 |
+
# -------------------------------- TimesFM + LSTM --------------------------------
|
1022 |
+
|
1023 |
+
st.subheader("Comparison: TimesFM vs. LSTM for Stock Price Prediction")
|
1024 |
+
st.markdown("""
|
1025 |
+
While TimesFM utilizes transformer-based architecture with 200M parameters and focuses on capturing complex temporal dependencies for accurate long-horizon forecasts, LSTM, with 66,000 parameters, leverages its recurrent structure to capture long-term dependencies in stock market data, offering effective prediction of price movements.
|
1026 |
+
""")
|
1027 |
+
|
1028 |
+
st.pyplot(self.plot_lstm_timefm_prediction(data = temp_df, lstm_prediction = lstm_predictions))
|
1029 |
+
|
1030 |
+
# --------------------------------------------------------------------------------
|
1031 |
+
|
1032 |
st.markdown("""---""")
|
1033 |
st.markdown("""
|
1034 |
## Conclusion
|
1035 |
|
1036 |
+
In conclusion, the success of stock market analysis relies on combining complementary technical indicators rather than solely relying on uniform signals. This diversification increases the chance of profitable outcomes by forming a robust system. Our comparison between TimesFM and LSTM emphasizes the importance of selecting models based on specific analytical needs. While TimesFM captures complex temporal dependencies effectively, LSTM excels in capturing long-term patterns. By integrating these insights, investors can make more informed decisions and navigate financial markets with greater confidence. """)
|
1037 |
+
|
1038 |
|
1039 |
+
if __name__ == "__main__":
|
1040 |
+
stonks = Stonks(stocks_filepath="Models/stocknames.csv")
|
1041 |
+
stonks.ui_renderer()
|
lstm.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn.preprocessing import MinMaxScaler
|
2 |
+
from keras.models import Sequential
|
3 |
+
from keras.layers import Dense, LSTM
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def __lstm__(stock_df, lookback: int = 60, train_fraction: float = 0.80):
    """Train a small stacked LSTM on the ``Close`` column and predict every row.

    The series is min-max scaled, cut into sliding windows of ``lookback``
    consecutive values, and a two-layer LSTM is fitted on the windows that lie
    inside the first ``train_fraction`` of the data. The fitted model is then
    run over windows drawn from the *whole* series, so the result contains
    both in-sample and out-of-sample predictions.

    Args:
        stock_df: DataFrame that contains a ``Close`` column.
        lookback: Number of past observations fed to the model per sample.
        train_fraction: Share of the rows (from the start) used for training.

    Returns:
        A ``(len(stock_df), 1)`` float array in the original price scale. The
        first ``lookback`` rows are ``NaN`` because no full window precedes
        them.

    Raises:
        ValueError: If ``Close`` is missing, ``lookback`` is not positive, or
            the training split is too short to yield a single window.
    """
    if "Close" not in stock_df.columns:
        raise ValueError("stock_df must contain a 'Close' column")
    if lookback < 1:
        raise ValueError("lookback must be a positive integer")

    # Column vector of closing prices, shape (n_rows, 1).
    dataset = stock_df.filter(["Close"]).values
    training_data_len = int(np.ceil(len(dataset) * train_fraction))

    # Without this guard the window arrays below are empty and the reshape
    # fails with an opaque IndexError.
    if training_data_len <= lookback:
        raise ValueError(
            f"Need more than {lookback} training rows to build a window; "
            f"got {training_data_len} (from {len(dataset)} rows)."
        )

    # NOTE(review): the scaler is fitted on the full series (train + held-out
    # rows), which leaks the held-out range into training. Kept as-is so the
    # outputs do not change; consider fitting on the training slice only.
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(dataset)

    # Training windows: each sample is `lookback` values, target is the next one.
    train_data = scaled_data[:training_data_len, :]
    x_train = np.array(
        [train_data[i - lookback:i, 0] for i in range(lookback, len(train_data))]
    )
    y_train = train_data[lookback:, 0]
    x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))

    # (B, lookback, 1) -> (B, lookback, 128) -> (B, 64) -> (B, 25) -> (B, 1)
    model = Sequential()
    model.add(LSTM(128, return_sequences=True, input_shape=(x_train.shape[1], 1), use_bias=False))
    model.add(LSTM(64, return_sequences=False, use_bias=False))
    model.add(Dense(25))
    model.add(Dense(1))

    model.compile(optimizer='adam', loss='mean_squared_error')
    model.fit(x_train, y_train, batch_size=1, epochs=1)

    # Inference windows cover the entire series, not only the held-out tail.
    x_all = np.array(
        [scaled_data[i - lookback:i, 0] for i in range(lookback, len(scaled_data))]
    )
    x_all = np.reshape(x_all, (x_all.shape[0], x_all.shape[1], 1))

    preds = scaler.inverse_transform(model.predict(x_all))

    # Left-pad with NaN so the output lines up row-for-row with stock_df.
    padding = np.full((lookback, 1), np.nan)
    return np.concatenate((padding, preds), axis=0)
|
77 |
+
|
78 |
+
# -------------------------------- timefm + LSTM ---------------------------------
|
79 |
+
|
80 |
+
# st.pyplot(plot_lstm_timefm_prediction(data = stock_df, lstm_prediction = predictions))
|
preds.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import timesfm
|
3 |
+
import logging
|
4 |
+
|
5 |
+
|
6 |
+
class Predictions:
    """Rolling-window forecaster backed by a TimesFM checkpoint.

    Typical use: construct once (loads the model), call ``data_preprocess``
    with a frame, then ``predict`` to obtain a series of forecasts aligned
    with the frame's positional index.
    """

    def __init__(self,
                 context_length: int = 512,
                 horizon_length: int = 14,
                 backend: str = "cpu",
                 checkpoint: str = "google/timesfm-1.0-200m",
                 ) -> None:
        """Build the TimesFM model and load its weights.

        Args:
            context_length: Max length of context on which predictions are
                made. Capped at 512, the maximum currently supported.
            horizon_length: Number of future steps predicted per forecast call;
                also the default step by which the window advances.
                Recommended: horizon length <= context length.
            backend: Backend used by the model ("cpu", "gpu" or "tpu").
            checkpoint: Repository id of the checkpoint to load.

        ``input_patch_len``, ``output_patch_len``, ``num_layers`` and
        ``model_dims`` are fixed values required to load the 200m model.
        """
        logging.info("Initializing Predictions class")

        self.default_step_size = horizon_length

        self.tfm = timesfm.TimesFm(
            context_len=min(context_length, 512),
            horizon_len=horizon_length,
            input_patch_len=32,
            output_patch_len=128,
            num_layers=20,
            model_dims=1280,
            backend=backend,
        )

        logging.info("Loading model from checkpoint")
        self.tfm.load_from_checkpoint(repo_id=checkpoint)

        logging.info("Model loaded successfully")

    def data_preprocess(self, data: pd.DataFrame, target_colm: str, date_colm: str) -> None:
        """Store a copy of ``data`` prepared for forecasting.

        Adds a ``ds`` datetime column parsed from ``date_colm``, casts
        ``target_colm`` to float, and sets the default initial window to one
        tenth of the number of rows.

        Args:
            data: Input frame; it is copied, not modified.
            target_colm: Name of the column to forecast.
            date_colm: Name of the column holding the dates.
        """
        self.data = data.copy()
        self.target_colm = target_colm
        self.default_window_size = len(self.data) // 10

        self.data["ds"] = pd.to_datetime(self.data[date_colm])
        self.data = self.data.astype({self.target_colm: float})

    def _iter_split(self, current_window: int, step_size: int):
        """Return the first ``current_window`` rows and a step clamped to the data end."""
        window_data = self.data[:current_window]

        # Never step past the last row: the final step may be shorter.
        remaining = len(self.data) - current_window
        if step_size > remaining:
            step_size = remaining

        return window_data, step_size

    def predict(self, intial_window_size: int = None, step: int = None, freq: str = "D"):
        """Forecast the target column with an expanding window.

        Starting from the first ``window_size`` rows, the model forecasts the
        next ``step_size`` rows, the window grows by that many rows, and the
        process repeats until the end of the data is reached.

        Args:
            intial_window_size: Rows in the first window. Defaults to one
                tenth of the data (set by ``data_preprocess``).
            step: Rows to advance per iteration. Defaults to the horizon
                length given at construction.
            freq: Frequency string forwarded to ``forecast_on_df``.

        Returns:
            A ``pd.Series`` of predictions indexed ``window_size .. len(data) - 1``
            (empty when the initial window already covers all rows).

        Raises:
            RuntimeError: If ``data_preprocess`` has not been called.
            ValueError: If the window or step is not positive, or a forecast
                returns fewer rows than the step requires.
        """
        if getattr(self, "data", None) is None:
            raise RuntimeError("data_preprocess() must be called before predict()")

        window_size = intial_window_size or self.default_window_size
        step_size = step or self.default_step_size
        if window_size <= 0:
            raise ValueError("initial window must contain at least one row")
        if step_size <= 0:
            raise ValueError("step must be a positive integer")

        self.data["unique_id"] = 0
        window = window_size
        batches = []

        while window < len(self.data):
            logging.info("Predicting for window size: %s", window)
            current_window, step_size = self._iter_split(window, step_size)
            batch_pred = self.tfm.forecast_on_df(
                current_window, freq=freq, value_name=self.target_colm
            )['timesfm']
            if len(batch_pred) < step_size:
                raise ValueError(
                    f"forecast returned {len(batch_pred)} rows but step is {step_size}; "
                    "use a step no larger than the model horizon"
                )
            # Keep only the rows the window is about to advance over, so each
            # prediction lines up with the row it forecasts. Slicing per batch
            # also avoids the `[:-0]` pitfall of trimming the tail afterwards,
            # which wiped the whole series when nothing needed trimming.
            batches.append(batch_pred.iloc[:step_size])
            window += step_size

        if not batches:
            return pd.Series(dtype=float)

        predictions = pd.concat(batches, ignore_index=True)
        predictions.index = range(window_size, window_size + len(predictions))
        return predictions
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
requirements.txt
CHANGED
@@ -1,8 +1,13 @@
|
|
|
|
1 |
matplotlib
|
2 |
mplfinance
|
3 |
numpy
|
4 |
pandas
|
5 |
pydantic
|
|
|
6 |
seaborn
|
7 |
streamlit
|
8 |
yfinance
|
|
|
|
|
|
|
|
1 |
+
keras
|
2 |
matplotlib
|
3 |
mplfinance
|
4 |
numpy
|
5 |
pandas
|
6 |
pydantic
|
7 |
+
scikit_learn
|
8 |
seaborn
|
9 |
streamlit
|
10 |
yfinance
|
11 |
+
jax[cpu]
|
12 |
+
jaxlib
|
13 |
+
git+https://github.com/google-research/timesfm.git
|