Sompote commited on
Commit
4cf4014
·
verified ·
1 Parent(s): 8d48176

Upload 5 files

Browse files
app2.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import torch
4
+ import shap
5
+ import matplotlib.pyplot as plt
6
+ import joblib
7
+ import pandas as pd
8
+ # Load scalers and model
9
+ @st.cache_resource
10
+ def load_resources():
11
+ scaler_X = joblib.load("scaler_X_DS.joblib")
12
+ scaler_y = joblib.load("scaler_y_DS.joblib")
13
+
14
+ model = torch.jit.load("scripted_model_DS.pt")
15
+ model.eval()
16
+
17
+ return scaler_X, scaler_y, model
18
+
19
+ # Create a wrapper function for SHAP
20
+ def model_wrapper(X):
21
+ with torch.no_grad():
22
+ X_tensor = torch.tensor(X, dtype=torch.float32)
23
+ output = model(X_tensor).numpy()
24
+ return scaler_y.inverse_transform(output)
25
+
26
+ # Streamlit app
27
+ st.title("Dynamic Stability Predictor")
28
+
29
+ # Load resources
30
+ scaler_X, scaler_y, model = load_resources()
31
+
32
+ # Define feature names and default values
33
+ feature_names = [
34
+ "25", "19", "12.5", "9.5", "4.75", "2.36", "1.18", "0.6", "0.3", "0.15", "0.075", "CA", "FA", "type"
35
+ ]
36
+ default_values = [100, 100, 81.593, 68.395, 49.318, 29.283, 17.261, 14.257, 6.041, 3.000, 2.115, 0.600, 0.350, 1.0]
37
+
38
+ # Input features
39
+ st.sidebar.header("Input Features")
40
+ input_features = {}
41
+ for feature, default_value in zip(feature_names, default_values):
42
+ if feature == "type":
43
+ type_option = st.sidebar.selectbox(f"Enter {feature}", options=["1 - Limestone", "2 - Basalt"], index=0)
44
+ input_features[feature] = 1.0 if type_option == "1 - Limestone" else 2.0
45
+ else:
46
+ input_features[feature] = st.sidebar.number_input(f"Enter {feature}", value=default_value)
47
+
48
+ # Create input array
49
+ input_array = np.array([input_features[feature] for feature in feature_names]).reshape(1, -1)
50
+ input_scaled = scaler_X.transform(input_array)
51
+
52
+ # Make prediction
53
+ with torch.no_grad():
54
+ prediction = model(torch.tensor(input_scaled, dtype=torch.float32)).numpy()
55
+ prediction_unscaled = scaler_y.inverse_transform(prediction)
56
+
57
+ st.write(f"Predicted Dynamic Stability: {prediction_unscaled[0][0]:.2f} pass/mm")
58
+
59
+ # SHAP explanation
60
+ if st.button("Explain Prediction"):
61
+ # Generate some random background data for SHAP
62
+ background_data = np.random.randn(100, 14) # 100 samples, 14 features
63
+ background_data_scaled = scaler_X.transform(background_data)
64
+
65
+ explainer = shap.KernelExplainer(model_wrapper, background_data_scaled)
66
+ shap_values = explainer.shap_values(input_scaled)
67
+
68
+ shap_values_single = shap_values[0].flatten()
69
+ expected_value = explainer.expected_value[0]
70
+
71
+ feature_values = [f"{x:.1f}" for x in input_array[0]]
72
+
73
+ explanation = shap.Explanation(
74
+ values=shap_values_single,
75
+ base_values=expected_value,
76
+ data=feature_values,
77
+ feature_names=feature_names
78
+ )
79
+
80
+ fig, ax = plt.subplots(figsize=(10, 6))
81
+ shap.plots.waterfall(explanation, show=False)
82
+ st.pyplot(fig)
83
+
84
+ st.write(f"Base value (unscaled): {([[expected_value]])[0][0]:.2f} pass/mm")
85
+ st.write(f"Output value (unscaled): {prediction_unscaled[0][0]:.2f} pass/mm")
86
+
87
+ st.write("\nFeature contributions (unscaled):")
88
+ feature_contributions = pd.DataFrame({
89
+ 'Contribution': shap_values_single
90
+ }, index=feature_names)
91
+ feature_contributions['Contribution'] = feature_contributions['Contribution'].round(4)
92
+ st.table(feature_contributions)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ torch
4
+ shap
5
+ matplotlib
6
+ joblib
scaler_X_DS.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6835ccf71cbd062cc20230e43eb0166f6a8a7bd2f7cd65fc37e657873e778edb
3
+ size 951
scaler_y_DS.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dd90832e71904d2eab0440e1f0ff044fbf6329b098215b024febf47c7596e92
3
+ size 623
scripted_model_DS.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa630b0d18893474bfadbe48873d929fbc464aaa7be2b34d1254e266196ea526
3
+ size 1651715