-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
273 lines (225 loc) · 13.1 KB
/
app.py
File metadata and controls
273 lines (225 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import streamlit as st
import pandas as pd
import numpy as np
import os
from data_loader import DataLoader
from feature_engineering import FeatureEngineer
from forecast_model import ForecastModel
from event_module import EventModule
from green_score import GreenScore
from explainability import ExplainabilityModule
import matplotlib.pyplot as plt
# Page Config
st.set_page_config(page_title="Voltacore Demand Forecasting", layout="wide")
st.title("Event-Aware Demand Forecasting System")
st.markdown("## Explainable AI & Real-Time Adjustment")

# Sidebar
st.sidebar.header("Configuration")
# SECURITY: never ship a real NewsAPI key as the widget default -- anything
# embedded in source is readable by everyone with access to the repository.
# The default is taken from the NEWSAPI_KEY environment variable (empty when
# unset); the user can still type or paste a key into the masked field.
# Any key that was previously committed here should be revoked and reissued.
api_key = st.sidebar.text_input(
    "NewsAPI Key",
    value=os.environ.get("NEWSAPI_KEY", ""),
    type="password",
)
# 1. Data Input
st.subheader("1. Data Loading")
uploaded_file = st.file_uploader("Upload Weekly Sales Data (Excel/CSV)", type=["xlsx", "csv"])

# Decide which dataset to analyse. An upload wins; otherwise the first bundled
# Walmart.csv location that exists is used. `file_path` stays None when
# neither source is available, and the rest of the script is then skipped.
file_path = None
if not uploaded_file:
    # Candidate locations for the default dataset, checked in this order.
    candidate_paths = (
        "../Walmart.csv",
        "Walmart.csv",
        "c:/Users/aagma_r95jbd4/sep_ip/Walmart.csv",
    )
    default_path = next((p for p in candidate_paths if os.path.exists(p)), None)
    if default_path is not None:
        st.info(f"No file uploaded. Using default training data: `{default_path}`")
        file_path = default_path
else:
    # Persist the upload to disk under a fixed name that keeps the original
    # extension; the file is overwritten on each run of the script.
    extension = uploaded_file.name.rsplit('.', 1)[-1]
    file_path = f"temp_data.{extension}"
    with open(file_path, "wb") as out:
        out.write(uploaded_file.getbuffer())
# Main pipeline: load -> feature engineering -> training -> next-week forecast
# -> news-based adjustment -> green score -> SHAP explanation.
# It runs only when a data source was resolved, and every stage sits inside a
# single try/except so any failure is reported through one st.error banner.
if file_path:
    loader = DataLoader(file_path)
    try:
        raw_df = loader.load_data()
        st.success("Data loaded successfully!")
        st.dataframe(raw_df.tail())

        # 2. Feature Engineering
        fe = FeatureEngineer()
        processed_df = fe.create_features(raw_df)

        # Train/Test Split
        X_train, X_test, y_train, y_test, feature_names = fe.split_data(processed_df)

        # 3. Model Training
        st.subheader("2. Model Training & Evaluation")
        if st.button("Train Model"):
            with st.spinner("Training XGBoost Regressor..."):
                model = ForecastModel()
                model.train(X_train[feature_names], y_train)
                metrics = model.evaluate(X_test[feature_names], y_test)

            st.write("### Model Performance metrics:")
            col1, col2, col3 = st.columns(3)
            col1.metric("MAE", f"{metrics['MAE']:.2f}")
            col2.metric("RMSE", f"{metrics['RMSE']:.2f}")
            col3.metric("MAPE", f"{metrics['MAPE']:.2f}%")

            # Keep the trained model and its context in session state so the
            # forecast section below can find them on later script runs.
            st.session_state['model'] = model
            st.session_state['feature_names'] = feature_names
            st.session_state['last_row'] = processed_df.iloc[-1]
            st.session_state['train_data'] = X_train[feature_names]

        # 4. Forecasting (only available once a model has been trained)
        if 'model' in st.session_state:
            st.markdown("---")
            st.subheader("3. Next Week Forecast")

            # Inputs for next week scenario
            col_a, col_b, col_c = st.columns(3)
            future_price = col_a.number_input(
                "Next Week Unit Price ($)",
                value=float(raw_df['Unit_Price_USD'].mean()),
            )
            future_holiday = col_b.checkbox("Next Week is Holiday?", value=False)
            future_promo = col_c.checkbox("Next Week Promotion?", value=False)

            if st.button("Generate Forecast", type="primary"):
                trained_model = st.session_state['model']
                trained_features = st.session_state['feature_names']

                # 1. Identify Time Context: forecast the week after the last known row.
                last_row = processed_df.iloc[-1]
                last_date = pd.to_datetime(last_row['Date'])
                future_date = last_date + pd.Timedelta(weeks=1)
                st.info(f"📅 Generating forecast for future date: **{future_date.date()}**")

                # 2. Construct Future Features
                # The new row's history-derived features are computed by hand.
                # Lag 1 = sales of the last known week.
                lag_1 = last_row['Weekly_Sales_Units']
                # Lag 2 = the last row's own Lag_1 (sales two weeks back); 0 if absent.
                lag_2 = last_row.get('Lag_1', 0)
                # Rolling mean over the last (up to) four known weekly sales values.
                rolling_mean = processed_df['Weekly_Sales_Units'].iloc[-4:].mean()

                # Time features
                month = future_date.month
                week_num = future_date.isocalendar().week

                # Assemble Input Dict
                input_data = {
                    'Unit_Price_USD': future_price,
                    # Inputs the user does not set fall back to the last known
                    # row, i.e. a "status quo" assumption (0 if the column is absent).
                    'Unit_Cost_USD': last_row.get('Unit_Cost_USD', 0),
                    'Current_Inventory_Units': last_row.get('Current_Inventory_Units', 0),
                    'Transport_Distance_km': last_row.get('Transport_Distance_km', 0),
                    'Fuel_Price_Index': last_row.get('Fuel_Price_Index', 0),
                    'Holiday_Flag': 1 if future_holiday else 0,
                    'Promotion_Flag': 1 if future_promo else 0,
                    'Month': month,
                    'Week_Number': week_num,
                    'Lag_1': lag_1,
                    'Lag_2': lag_2,
                    'Rolling_Mean_4': rolling_mean,
                }

                # Any other feature the model was trained on takes its latest
                # known value (0 if the last row does not carry it).
                for feature in trained_features:
                    if feature not in input_data:
                        input_data[feature] = last_row.get(feature, 0)

                # Select exactly the training columns, in training order.
                param_df = pd.DataFrame([input_data])
                X_next_week = param_df[trained_features]

                # 3. Base Forecast Prediction
                base_forecast = trained_model.predict(X_next_week)[0]
                conf_low, conf_high = trained_model.get_confidence_interval(base_forecast)

                # Visualization of Forecast
                st.write(f"### 📉 Forecast for {future_date.date()}: **{base_forecast:,.0f} units**")

                # Comparison chart: most recent actual weeks + 1 predicted week.
                recent_weeks = processed_df.iloc[-4:].copy()
                dates = [str(d.date()) for d in recent_weeks['Date']] + [str(future_date.date())]
                sales = list(recent_weeks['Weekly_Sales_Units']) + [base_forecast]
                # Size the label list from the rows actually selected. A fixed
                # `* 4` gave columns of unequal length whenever fewer than four
                # weeks of history exist, and the DataFrame constructor raised.
                point_types = ['Actual'] * len(recent_weeks) + ['Forecast']
                chart_df = pd.DataFrame({
                    'Week': dates,
                    'Sales': sales,
                    'Type': point_types,
                })

                # Plot
                st.line_chart(chart_df.set_index('Week')['Sales'])
                st.caption(f"Comparison: Last 4 weeks actuals vs. Forecast for {future_date.date()}. 95% CI: {conf_low:,.0f} - {conf_high:,.0f}")

                # 5. Real-Time Event Adjustment
                st.markdown("---")
                st.subheader("4. Real-Time Event Adjustment")
                event_mod = EventModule(api_key)
                with st.spinner("Fetching global news..."):
                    news_texts = event_mod.fetch_news()
                    impact_data = event_mod.compute_event_impact_score(news_texts)

                # Unpack Data
                raw_score = impact_data['score']
                impact_type = impact_data['type']
                sensitivity = impact_data['sensitivity']
                adjusted_effect = impact_data['adjusted_effect']
                interpretation = impact_data['interpretation']

                # Display Metrics
                col_e1, col_e2, col_e3 = st.columns(3)
                col_e1.metric("Sentiment Score", f"{raw_score:.2f}")
                col_e2.metric("Event Type", impact_type)
                col_e3.metric("EV Sensitivity", sensitivity)

                # Banner style follows the sign of the adjusted effect.
                if adjusted_effect > 0:
                    st.success(f"✅ {interpretation}")
                elif adjusted_effect < 0:
                    st.warning(f"⚠️ {interpretation}")
                else:
                    st.info(f"ℹ️ {interpretation}")

                # Adjustment factor. alpha is a positive coefficient because
                # adjusted_effect already carries the sign. It is the single
                # source of truth: the formula text and the delta percentage
                # below are derived from it instead of repeating 0.2 / 20.
                alpha = 0.2
                adjustment_factor = 1 + (alpha * adjusted_effect)
                adjusted_forecast = base_forecast * adjustment_factor

                # Show calculation transparency
                with st.expander("📝 View Calculation Details"):
                    st.markdown(f"""
**Adjustment Formula:**
`Adjusted_Forecast = Base_Forecast × (1 + {alpha} × Adjusted_Effect)`
* **Base Sentiment**: {raw_score:.2f}
* **Event Type**: {impact_type} ({sensitivity} Impact)
* **Adjusted Effect**: {adjusted_effect:.2f} (Effective impact on demand)
* **Final Factor**: {adjustment_factor:.3f}
*Note: "Positive" sensitivity means bad news (e.g. Oil War) increases EV demand.*
""")

                # Two-column layout; only the left column is populated.
                col_x, _ = st.columns(2)
                col_x.metric(
                    "Adjusted Forecast",
                    f"{adjusted_forecast:,.0f} units",
                    delta=f"{(adjusted_forecast - base_forecast):,.0f} ({alpha * adjusted_effect * 100:.1f}%)",
                )

                with st.expander("📰 See Top Influencing Headlines"):
                    if news_texts:
                        # Show at most five headlines, each cut to 150 characters.
                        for i, txt in enumerate(news_texts[:5]):
                            st.write(f"**{i+1}.** {txt[:150]}...")
                    else:
                        st.write("No relevant news found.")

                # 6. Green Score
                st.markdown("---")
                st.subheader("5. Sustainability Score")
                st.markdown("""
**What is this?** A 0-100 score indicating the environmental efficiency.
* **High Score (80-100)**: Excellent efficiency.
* **Low Score (<50)**: High carbon footprint or waste.
""")
                gs = GreenScore()
                # Scored from the same scenario inputs used for the forecast,
                # together with the event-adjusted forecast.
                green_score = gs.compute_score(
                    input_data.get('Transport_Distance_km', 0),
                    input_data.get('Fuel_Price_Index', 0),
                    input_data.get('Current_Inventory_Units', 0),
                    adjusted_forecast,
                )
                recommendation = gs.get_recommendation(green_score)
                col_g1, col_g2 = st.columns([1, 2])
                col_g1.metric("Green Score", f"{green_score}/100")
                col_g2.info(f"💡 Recommendation: {recommendation}")

                # 7. Explainability
                st.markdown("---")
                st.subheader("6. Forecast Explanation (SHAP)")
                st.markdown("### 🔍 Why did the model predict this?")
                st.info("""
This chart ranks the factors driving your sales forecast from most important (top) to least important.
**How to read this for your business:**
1. **Holiday_Flag**: If this is at the top, it means **Seasonality/Holidays** are the biggest driver of your sales.
* *Red dots to the right* → Holidays significantly boost your sales.
2. **Unit_Price_USD**: If this is high up, your customers are very **price-sensitive**.
* *Red dots onto the left* → Higher prices sharply reduce demand (Elastic demand).
* *Blue dots onto the right* → Lower prices drive volume.
3. **Lag_1**: This represents **Momentum**. If top-ranked, your sales trend is sticky (last week strongly predicts this week).
4. **Fuel_Price / Unemployment**: These capture **Macroeconomic** impact.
*The 'SHAP Value' on the bottom axis represents the actual units added or subtracted from the forecast by that factor.*
""")
                # The summary plot is drawn over the training features saved at train time.
                explainer = ExplainabilityModule(st.session_state['model'].model)
                st.pyplot(explainer.plot_shap_summary(st.session_state['train_data']))
                st.write("**Visual Guide**: The width of the spread shows the variability of impact. A wide spread means that feature makes a huge difference depending on its value.")
    except Exception as e:
        # Top-level UI boundary: report the failure in the page.
        st.error(f"Error processing file: {e}")