Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,9 @@
|
|
1 |
-
import
|
2 |
import numpy as np
|
3 |
-
import
|
4 |
-
import folium
|
5 |
-
import streamlit as st
|
6 |
-
from streamlit_folium import folium_static
|
7 |
-
import warnings
|
8 |
-
warnings.filterwarnings("ignore")
|
9 |
|
10 |
# Define model paths
|
11 |
-
model_paths = {
|
12 |
'Path': {
|
13 |
'3 hours': 'lr_3H_lat_lon.pkl',
|
14 |
'6 hours': 'lr_6H_lat_lon.pkl',
|
@@ -22,6 +17,11 @@ model_paths = {
|
|
22 |
'30 hours': 'lr_30H_lat_lon.pkl',
|
23 |
'33 hours': 'lr_33H_lat_lon.pkl',
|
24 |
'36 hours': 'lr_36H_lat_lon.pkl'
|
|
|
|
|
|
|
|
|
|
|
25 |
}
|
26 |
}
|
27 |
|
@@ -34,94 +34,131 @@ scaler_paths = {
|
|
34 |
'12 hours': 'lr_12H_lat_lon_scaler.pkl',
|
35 |
'15 hours': 'lr_15H_lat_lon_scaler.pkl',
|
36 |
'18 hours': 'lr_18H_lat_lon_scaler.pkl',
|
|
|
37 |
'24 hours': 'lr_24H_lat_lon_scaler.pkl',
|
38 |
'27 hours': 'lr_27H_lat_lon_scaler.pkl',
|
39 |
'30 hours': 'lr_30H_lat_lon_scaler.pkl',
|
40 |
'33 hours': 'lr_33H_lat_lon_scaler.pkl',
|
41 |
'36 hours': 'lr_36H_lat_lon_scaler.pkl'
|
|
|
|
|
|
|
|
|
|
|
42 |
}
|
43 |
}
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
|
51 |
-
def process_input(input_data, scaler):
|
52 |
input_data = np.array(input_data).reshape(-1, 7)
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
processed_data = scaler.transform(processed_data)
|
55 |
return processed_data
|
56 |
|
57 |
-
def
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
# Create DataFrame for predictions
|
63 |
-
df_predictions = pd.DataFrame(prediction, columns=['LAT', 'LON'])
|
64 |
-
df_predictions['Time'] = [time_interval]
|
65 |
-
return df_predictions
|
66 |
-
|
67 |
-
# Function to plot predictions on a folium map and return the HTML representation
|
68 |
-
def plot_predictions_on_map(df_predictions):
|
69 |
-
latitudes = df_predictions['LAT'].tolist()
|
70 |
-
longitudes = df_predictions['LON'].tolist()
|
71 |
-
|
72 |
-
m = folium.Map(location=[latitudes[0], longitudes[0]], zoom_start=6)
|
73 |
-
locations = list(zip(latitudes, longitudes))
|
74 |
-
|
75 |
-
for lat, lon in locations:
|
76 |
-
folium.Marker([lat, lon]).add_to(m)
|
77 |
-
|
78 |
-
folium.PolyLine(locations, color='blue', weight=2.5, opacity=0.7).add_to(m)
|
79 |
-
return m
|
80 |
-
|
81 |
-
# Streamlit App
|
82 |
-
def main():
|
83 |
-
st.title("Cyclone Path Prediction")
|
84 |
-
st.write("Input current and previous cyclone data to predict the path and visualize it on a map.")
|
85 |
-
|
86 |
-
# User inputs
|
87 |
-
time_interval = st.selectbox("Select Prediction Time Interval", [
|
88 |
-
'3 hours', '6 hours', '9 hours', '12 hours', '15 hours', '18 hours',
|
89 |
-
'21 hours', '24 hours', '27 hours', '30 hours', '33 hours', '36 hours'
|
90 |
-
])
|
91 |
-
|
92 |
-
previous_lat = st.number_input("Previous Latitude", format="%f")
|
93 |
-
previous_lon = st.number_input("Previous Longitude", format="%f")
|
94 |
-
previous_speed = st.number_input("Previous Speed", format="%f")
|
95 |
-
previous_year = st.number_input("Previous Year", format="%d")
|
96 |
-
previous_month = st.number_input("Previous Month", format="%d")
|
97 |
-
previous_day = st.number_input("Previous Day", format="%d")
|
98 |
-
previous_hour = st.number_input("Previous Hour", format="%d")
|
99 |
-
|
100 |
-
present_lat = st.number_input("Present Latitude", format="%f")
|
101 |
-
present_lon = st.number_input("Present Longitude", format="%f")
|
102 |
-
present_speed = st.number_input("Present Speed", format="%f")
|
103 |
-
present_year = st.number_input("Present Year", format="%d")
|
104 |
-
present_month = st.number_input("Present Month", format="%d")
|
105 |
-
present_day = st.number_input("Present Day", format="%d")
|
106 |
-
present_hour = st.number_input("Present Hour", format="%d")
|
107 |
-
|
108 |
-
if st.button("Predict"):
|
109 |
-
# Process input into array format
|
110 |
-
previous_data = [previous_lat, previous_lon, previous_speed, previous_year, previous_month, previous_day, previous_hour]
|
111 |
-
present_data = [present_lat, present_lon, present_speed, present_year, present_month, present_day, present_hour]
|
112 |
-
input_data = [previous_data, present_data]
|
113 |
|
114 |
-
#
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
import numpy as np
|
3 |
+
import joblib
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# Define model paths
|
6 |
+
model_paths = {
|
7 |
'Path': {
|
8 |
'3 hours': 'lr_3H_lat_lon.pkl',
|
9 |
'6 hours': 'lr_6H_lat_lon.pkl',
|
|
|
17 |
'30 hours': 'lr_30H_lat_lon.pkl',
|
18 |
'33 hours': 'lr_33H_lat_lon.pkl',
|
19 |
'36 hours': 'lr_36H_lat_lon.pkl'
|
20 |
+
},
|
21 |
+
'Speed': {
|
22 |
+
'3 hours': 'lgbm_3H_speed.pkl',
|
23 |
+
'15 hours': 'lgbm_15H_speed.pkl',
|
24 |
+
'27 hours': 'lgbm_27H_speed.pkl'
|
25 |
}
|
26 |
}
|
27 |
|
|
|
34 |
'12 hours': 'lr_12H_lat_lon_scaler.pkl',
|
35 |
'15 hours': 'lr_15H_lat_lon_scaler.pkl',
|
36 |
'18 hours': 'lr_18H_lat_lon_scaler.pkl',
|
37 |
+
'21 hours': 'lr_21H_lat_lon_scaler.pkl',
|
38 |
'24 hours': 'lr_24H_lat_lon_scaler.pkl',
|
39 |
'27 hours': 'lr_27H_lat_lon_scaler.pkl',
|
40 |
'30 hours': 'lr_30H_lat_lon_scaler.pkl',
|
41 |
'33 hours': 'lr_33H_lat_lon_scaler.pkl',
|
42 |
'36 hours': 'lr_36H_lat_lon_scaler.pkl'
|
43 |
+
},
|
44 |
+
'Speed': {
|
45 |
+
'3 hours': 'lgbm_speed_scale_3H.pkl',
|
46 |
+
'15 hours': 'lgbm_speed_scale_15H.pkl',
|
47 |
+
'27 hours': 'lgbm_speed_scaler_27H.pkl'
|
48 |
}
|
49 |
}
|
50 |
|
51 |
+
# Define time intervals for each prediction type
|
52 |
+
time_intervals = {
|
53 |
+
'Path': ['3 hours', '6 hours', '9 hours', '12 hours', '15 hours', '18 hours', '21 hours', '24 hours', '27 hours', '30 hours', '33 hours', '36 hours'],
|
54 |
+
'Speed': ['3 hours', '15 hours', '27 hours']
|
55 |
+
}
|
56 |
|
57 |
+
def process_input(input_data, scaler, prediction_type):
|
58 |
input_data = np.array(input_data).reshape(-1, 7)
|
59 |
+
if prediction_type == 'Speed':
|
60 |
+
# For speed prediction, reshape accordingly
|
61 |
+
input_data = input_data[:2].reshape(1, 2, 7)
|
62 |
+
processed_data = input_data.reshape(-1, 14)
|
63 |
+
else: # Path
|
64 |
+
processed_data = input_data[:2].reshape(1, -1)
|
65 |
processed_data = scaler.transform(processed_data)
|
66 |
return processed_data
|
67 |
|
68 |
+
def load_model_and_predict(prediction_type, time_interval, input_data):
|
69 |
+
try:
|
70 |
+
# Load the model and scaler based on user selection
|
71 |
+
model = joblib.load(model_paths[prediction_type][time_interval])
|
72 |
+
scaler = joblib.load(scaler_paths[prediction_type][time_interval])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
# Process input and predict
|
75 |
+
processed_data = process_input(input_data, scaler, prediction_type)
|
76 |
+
prediction = model.predict(processed_data)
|
77 |
+
|
78 |
+
if prediction_type == 'Path':
|
79 |
+
return f"Predicted Path after {time_interval}: Latitude: {prediction[0][0]}, Longitude: {prediction[0][1]}"
|
80 |
+
elif prediction_type == 'Speed':
|
81 |
+
return f"Predicted Speed after {time_interval}: {prediction[0]}"
|
82 |
+
except Exception as e:
|
83 |
+
return str(e)
|
84 |
+
|
85 |
+
# Gradio interface components
|
86 |
+
with gr.Blocks() as cyclone_predictor:
|
87 |
+
gr.Markdown("# Cyclone Path and Speed Prediction App")
|
88 |
+
|
89 |
+
# Dropdown for Prediction Type
|
90 |
+
prediction_type = gr.Dropdown(
|
91 |
+
choices=['Path', 'Speed'],
|
92 |
+
value='Path',
|
93 |
+
label="Select Prediction Type"
|
94 |
+
)
|
95 |
+
|
96 |
+
# Dropdown for Time Interval
|
97 |
+
time_interval = gr.Dropdown(
|
98 |
+
choices=time_intervals['Path'],
|
99 |
+
label="Select Time Interval"
|
100 |
+
)
|
101 |
+
|
102 |
+
# Function to update time intervals based on prediction type
|
103 |
+
def update_time_intervals(prediction_type_value):
|
104 |
+
return gr.update(choices=time_intervals[prediction_type_value])
|
105 |
+
|
106 |
+
# Update time intervals when prediction type changes
|
107 |
+
prediction_type.change(
|
108 |
+
fn=update_time_intervals,
|
109 |
+
inputs=prediction_type,
|
110 |
+
outputs=time_interval
|
111 |
+
)
|
112 |
+
|
113 |
+
# Input fields for user data
|
114 |
+
previous_lat_lon = gr.Textbox(
|
115 |
+
placeholder="Enter previous 3-hour lat/lon (e.g., 15.54,90.64)",
|
116 |
+
label="Previous 3-hour Latitude/Longitude"
|
117 |
+
)
|
118 |
+
previous_speed = gr.Number(label="Previous 3-hour Speed")
|
119 |
+
previous_timestamp = gr.Textbox(
|
120 |
+
placeholder="Enter previous 3-hour timestamp (e.g., 2024,10,23,0)",
|
121 |
+
label="Previous 3-hour Timestamp (year, month, day, hour)"
|
122 |
+
)
|
123 |
+
|
124 |
+
present_lat_lon = gr.Textbox(
|
125 |
+
placeholder="Enter present 3-hour lat/lon (e.g., 15.71,90.29)",
|
126 |
+
label="Present 3-hour Latitude/Longitude"
|
127 |
+
)
|
128 |
+
present_speed = gr.Number(label="Present 3-hour Speed")
|
129 |
+
present_timestamp = gr.Textbox(
|
130 |
+
placeholder="Enter present 3-hour timestamp (e.g., 2024,10,23,3)",
|
131 |
+
label="Present 3-hour Timestamp (year, month, day, hour)"
|
132 |
+
)
|
133 |
+
|
134 |
+
# Output prediction
|
135 |
+
prediction_output = gr.Textbox(label="Prediction Output")
|
136 |
+
|
137 |
+
# Predict button
|
138 |
+
def get_input_data(previous_lat_lon, previous_speed, previous_timestamp, present_lat_lon, present_speed, present_timestamp):
|
139 |
+
try:
|
140 |
+
# Parse inputs into required format
|
141 |
+
prev_lat, prev_lon = map(float, previous_lat_lon.split(','))
|
142 |
+
prev_time = list(map(int, previous_timestamp.split(',')))
|
143 |
+
previous_data = [prev_lat, prev_lon, previous_speed] + prev_time
|
144 |
+
|
145 |
+
present_lat, present_lon = map(float, present_lat_lon.split(','))
|
146 |
+
present_time = list(map(int, present_timestamp.split(',')))
|
147 |
+
present_data = [present_lat, present_lon, present_speed] + present_time
|
148 |
+
|
149 |
+
return [previous_data, present_data]
|
150 |
+
except Exception as e:
|
151 |
+
return str(e)
|
152 |
+
|
153 |
+
predict_button = gr.Button("Predict")
|
154 |
+
|
155 |
+
# Linking function to UI elements
|
156 |
+
predict_button.click(
|
157 |
+
fn=lambda pt, ti, p_lat_lon, p_speed, p_time, c_lat_lon, c_speed, c_time: load_model_and_predict(
|
158 |
+
pt, ti, get_input_data(p_lat_lon, p_speed, p_time, c_lat_lon, c_speed, c_time)
|
159 |
+
),
|
160 |
+
inputs=[prediction_type, time_interval, previous_lat_lon, previous_speed, previous_timestamp, present_lat_lon, present_speed, present_timestamp],
|
161 |
+
outputs=prediction_output
|
162 |
+
)
|
163 |
+
|
164 |
+
cyclone_predictor.launch()
|