Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 23 additions & 33 deletions tutorials/GluonTS_COVID19_Prediction/GluonTS.API.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.19.0
# jupytext_version: 1.19.1
# kernelspec:
# display_name: Python 3
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
Expand Down Expand Up @@ -54,25 +54,39 @@
# Once you're comfortable with the mechanics here, move to `GluonTS.example.ipynb` to see these models applied to real COVID-19 data.

# %% [markdown]
# ---
#
# ## Setup

# %%
# %load_ext autoreload
# %autoreload 2

# System libraries.
import logging
import warnings

warnings.filterwarnings("ignore")

# Core GluonTS components for the API tutorial
# Third party libraries.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# %% [markdown]
# ## GluonTS imports and utilities

# %%
from gluonts.evaluation import make_evaluation_predictions
from gluonts.torch.model.deep_npts import DeepNPTSEstimator
from gluonts.torch.model.deepar import DeepAREstimator
from gluonts.torch.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.torch.model.deep_npts import DeepNPTSEstimator
from gluonts.evaluation import make_evaluation_predictions

# All our utilities are in one place - much cleaner!
import GluonTS_utils as gluonts

print("Setup complete. Ready to explore GluonTS.")
# %%
_LOG = logging.getLogger(__name__)
gluonts.init_logger(_LOG)
_LOG.info("Setup complete. Ready to explore GluonTS.")

# %% [markdown]
# ## The GluonTS Workflow
Expand All @@ -90,8 +104,6 @@
# Let's see this in action, starting with the simplest possible time series.

# %% [markdown]
# ---
#
# # Level 1: Sinusoid — The Simplest Pattern
#
# We begin with a pure sine wave plus a small amount of Gaussian noise. This is the easiest pattern a model can encounter: perfectly periodic, no trend, no regime changes.
Expand Down Expand Up @@ -245,8 +257,6 @@
# %% [markdown]
# > **Checkpoint:** You just completed the full GluonTS workflow — configure, train, forecast, visualize, evaluate. This is the same 5-step pattern for *every* GluonTS model. From here on we'll move faster since you know the drill. What changes is the **data** (increasing complexity) and the **model** (different architectures), not the workflow itself.
#
# ---
#
# # Level 2: Multi-Frequency — Adding Realism
#
# Real time series rarely consist of a single clean cycle. This synthetic series combines four components you'll encounter in real data:
Expand Down Expand Up @@ -372,8 +382,6 @@
)

# %% [markdown]
# ---
#
# # Level 3: Regime Change — The Hard Problem
#
# This is where things get interesting. The series behaves one way for the first half, then **abruptly shifts** to a different baseline, amplitude, and frequency.
Expand Down Expand Up @@ -506,15 +514,11 @@
print(f" MAPE: {deepar_regime_metrics['mape']:.2f}%")

# %% [markdown]
# ---
#
# # Model Comparison
#
# Let's bring all results together. Each model was tested on the data type that best highlights its strengths and weaknesses.

# %%
import pandas as pd

comparison = pd.DataFrame(
{
"Model": [
Expand Down Expand Up @@ -551,8 +555,6 @@
print(comparison.to_string(index=False, float_format="%.2f"))

# %% [markdown]
# ---
#
# # Summary
#
# ## What You Learned
Expand All @@ -563,8 +565,6 @@
# 4. **Probabilistic output** — every forecast gives you means, medians, quantiles, and raw samples
# 5. **Model choice matters** — the right model depends on your data's characteristics
#
# ---
#
# ## Which Model Should You Choose?
#
# | If your data has... | Try this model | Why |
Expand All @@ -574,8 +574,6 @@
# | Regime shifts, unusual distributions | **DeepNPTS** | Non-parametric — adapts to changing behavior |
# | No idea yet | **Start with SimpleFeedForward** | Fast to test, then try DeepAR for more accuracy |
#
# ---
#
# ## Quick Reference
#
# | Task | Code |
Expand All @@ -585,8 +583,6 @@
# | 80% confidence interval | `forecast.quantile(0.1)` to `forecast.quantile(0.9)` |
# | Raw sample paths | `forecast.samples` (shape: `num_samples × prediction_length`) |
#
# ---
#
# ## Tips for Better Results
#
# | Area | Tip |
Expand All @@ -596,8 +592,6 @@
# | **Features** | More isn't always better — test with and without |
# | **Data quality** | Handle missing values and normalize if needed |
#
# ---
#
# ## Troubleshooting
#
# | Problem | Fix |
Expand All @@ -607,14 +601,10 @@
# | *Poor forecast quality* | Increase `context_length`, train longer, or try DeepNPTS for regime changes |
# | *"Unexpected keyword argument"* | DeepAR: `trainer_kwargs={"max_epochs": N}`. DeepNPTS: `epochs=N`. SimpleFeedForward: no `freq` parameter |
#
# ---
#
# ## Resources
#
# - [GluonTS Documentation](https://ts.gluon.ai/) · [GitHub](https://github.com/awslabs/gluonts) · [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)
#
# ---
#
# ## What's Next?
#
# Now that you understand how GluonTS works on clean synthetic data, move to **`GluonTS.example.ipynb`** to see these same models applied to real **COVID-19 case prediction** — with feature engineering, multiple covariates, scenario analysis, and all the messiness of real-world data.
Expand Down
45 changes: 21 additions & 24 deletions tutorials/GluonTS_COVID19_Prediction/GluonTS.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.19.0
# jupytext_version: 1.19.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand All @@ -25,21 +25,31 @@
# **Models:** We compare DeepAR (complex patterns), SimpleFeedForward (fast baseline), and DeepNPTS (regime changes).

# %% [markdown]
# ---
#
# ## 1. Setup and Imports
# ## 1. Setup and imports
#
# Let's get everything set up for our COVID-19 forecasting analysis.

# %%
# %load_ext autoreload
# %autoreload 2

# System libraries.
import logging
import warnings

warnings.filterwarnings("ignore")

# All our utilities in one place - much cleaner!
import GluonTS_utils as gluonts
# Third party libraries.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Explicit imports for functions called without gluonts. prefix
# %% [markdown]
# ## GluonTS utilities

# %%
import GluonTS_utils as gluonts
from GluonTS_utils import (
train_deepar_covid,
train_feedforward_covid,
Expand All @@ -51,11 +61,12 @@
print_policy_insights,
)

print("Setup complete. Ready to forecast COVID-19 cases.")
# %%
_LOG = logging.getLogger(__name__)
gluonts.init_logger(_LOG)
_LOG.info("Setup complete. Ready to forecast COVID-19 cases.")

# %% [markdown]
# ---
#
# ## 2. Load and Explore COVID-19 Data
#
# Let's load our real COVID-19 data and take a look at what we're working with.
Expand Down Expand Up @@ -116,8 +127,6 @@
# **Why these features?** Target (7-day MA) smooths reporting; deaths lag cases and correlate with outcomes; CFR indicates strain; mobility captures lockdown effects.

# %% [markdown]
# ---
#
# ## 3. Feature Engineering
#
# Our data pipeline has already engineered several features to improve model performance:
Expand Down Expand Up @@ -148,8 +157,6 @@
)

# %% [markdown]
# ---
#
# ## 4. Train All Three Models
#
# **Model choice:** DeepAR for complex wave patterns; SimpleFeedForward for a fast baseline; DeepNPTS for regime shifts across COVID variants.
Expand Down Expand Up @@ -230,8 +237,6 @@
)

# %% [markdown]
# ---
#
# ## 5. Compare Models
#
# Now that we've trained all three models, let's compare their performance!
Expand Down Expand Up @@ -293,8 +298,6 @@
)

# %% [markdown]
# ---
#
# ## 6. Scenario Analysis: Simulating Interventions
#
# One of the most powerful applications of forecasting is **scenario analysis** -
Expand Down Expand Up @@ -372,8 +375,6 @@
print_policy_insights(scenario_results)

# %% [markdown]
# ---
#
# ## 7. Conclusions and Recommendations
#
# ### Key Takeaways
Expand Down Expand Up @@ -438,8 +439,6 @@
# 4. **Scalability**: Use GPU acceleration for faster training
# 5. **Interpretability**: Provide explanations alongside forecasts
#
# ---
#
# ## Congratulations!
#
# You've completed a full end-to-end COVID-19 forecasting application!
Expand All @@ -455,8 +454,6 @@
# **Ready to apply these skills to your own forecasting problems?**

# %% [markdown]
# ---
#
# ## Additional Resources
#
# **GluonTS Documentation**
Expand Down
27 changes: 27 additions & 0 deletions tutorials/GluonTS_COVID19_Prediction/GluonTS_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,33 @@
warnings.filterwarnings("ignore")


def init_logger(notebook_log: logging.Logger) -> None:
"""
Configure notebook display and route loggers to print for Jupyter output.

Standalone tutorial images do not include the helpers package, so this
mirrors the essentials of helpers.hnotebook.config_notebook locally.
"""
import seaborn as sns

plt.rcParams["figure.figsize"] = (12, 6)
plt.rcParams["legend.fontsize"] = 12
plt.rcParams["font.size"] = 12
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)
pd.set_option("display.width", 1000)
sns.set()

def _info_print(msg: str, *args, **kwargs) -> None:
if args:
msg = msg % args
print(msg)

notebook_log.info = _info_print # type: ignore[method-assign]
global _LOG
_LOG.info = _info_print # type: ignore[method-assign]


# #############################################################################
# Analysis
# #############################################################################
Expand Down
Loading