This repository is wooheon's reproduction of StemGNN (NeurIPS 2020).
Install the dependencies:

```shell
pip install --upgrade pip
pip install -r requirements.txt
```
The official implementation is available at https://github.com/microsoft/StemGNN.
Train and evaluate the model with:

```shell
python main.py --dataset <name of csv file> --window_size <length of sliding window> --horizon <predict horizon> --norm_method z_score --batch_size 64 --train_length 7 --validate_length 2 --test_length 1
```
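The flags in the `python main.py` command control how the raw CSV becomes training samples. Below is a minimal NumPy sketch of that pipeline, not this repository's code. It assumes the CSV has already been loaded into a `(time, nodes)` array, that `train_length`/`validate_length`/`test_length` are relative proportions (7:2:1 gives a chronological 70/20/10 split), that `z_score` statistics are taken from the training split only, and that each sample pairs `window_size` past steps with the next `horizon` steps. The helpers `split_by_ratio`, `z_score`, and `sliding_windows` are hypothetical names.

```python
import numpy as np


def split_by_ratio(data, train_length=7, validate_length=2, test_length=1):
    """Split a (time, nodes) array chronologically.

    Assumption: the three lengths are relative proportions (7:2:1 by default).
    """
    total = train_length + validate_length + test_length
    n = len(data)
    train_end = int(n * train_length / total)
    valid_end = int(n * (train_length + validate_length) / total)
    return data[:train_end], data[train_end:valid_end], data[valid_end:]


def z_score(train, *others):
    """Normalize every split with the per-node mean/std of the training split."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std = np.where(std == 0, 1.0, std)  # avoid division by zero for constant series
    return [(x - mean) / std for x in (train, *others)], mean, std


def sliding_windows(data, window_size=12, horizon=3):
    """Inputs: (samples, window_size, nodes); targets: (samples, horizon, nodes)."""
    xs, ys = [], []
    for start in range(len(data) - window_size - horizon + 1):
        end = start + window_size
        xs.append(data[start:end])          # past window_size steps
        ys.append(data[end:end + horizon])  # next horizon steps to predict
    return np.stack(xs), np.stack(ys)


if __name__ == "__main__":
    series = np.arange(200, dtype=float).reshape(100, 2)  # 100 time steps, 2 nodes
    train, valid, test = split_by_ratio(series)
    print(len(train), len(valid), len(test))  # 70 20 10
    (train, valid, test), mean, std = z_score(train, valid, test)
    x, y = sliding_windows(train, window_size=12, horizon=3)
    print(x.shape, y.shape)  # (56, 12, 2) (56, 3, 2)
```

Taking normalization statistics from the training split alone keeps information from the validation and test periods out of training.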
The parameters are described in detail below. New parameters are shown in bold.
| Parameter name | Description of parameter |
|---|---|
| dataset | file name of the input CSV |
| window_size | length of the input sliding window, default 12 |
| horizon | prediction horizon (number of future steps), default 3 |
| train_length | relative length of the training split, default 7 |
| validate_length | relative length of the validation split, default 2 |
| test_length | relative length of the test split, default 1 |
| epoch | number of training epochs |
| optimizer | optimizer, default RMSProp |
| lr | learning rate, default 1e-3 |
| decay_rate | decay rate of the learning rate, default 0.7 |
| exponential_decay_step | step interval of the exponential learning-rate decay, default 5 |
| randomwalk_laplacian | whether to use the random-walk normalized Laplacian matrix |
| attention_channel | channel size of the latent correlation layer, default 32 |
| kernel_size | kernel size of the Gated CNN, default 3 |
| gcnn_channel | channel size of the Gated CNN, default 32 |
| gconv_channel | channel size of the graph convolution, default 64 |
| multi_channel | output channel size of the StemBlock's forecast and backcast branches, default 128 |
| device | device to run on, 'cpu' or 'cuda:x' |
| validate_freq | how often validation is run during training |
| batch_size | batch size, default 64 |
| dropout_rate | dropout rate, default 0.2 |
| leakyrelu_rate | negative slope of the LeakyReLU activation, default 0.5 |
| norm_method | normalization method, 'z_score' or 'min_max' |
| early_stop | whether to enable early stopping, default False |
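The `randomwalk_laplacian` flag selects how the graph is normalized before graph convolution. As a hedged NumPy illustration (not the repository's implementation; `normalized_laplacian` is a hypothetical name), the sketch below assumes the flag switches between the symmetric normalized Laplacian `L_sym = I - D^(-1/2) A D^(-1/2)` and the random-walk normalized Laplacian `L_rw = I - D^(-1) A`, where `A` is the (learned) adjacency matrix and `D` its diagonal degree matrix.

```python
import numpy as np


def normalized_laplacian(adj, random_walk=False, eps=1e-7):
    """Normalized Laplacian of a (weighted) adjacency matrix.

    random_walk=False: L_sym = I - D^(-1/2) A D^(-1/2)
    random_walk=True:  L_rw  = I - D^(-1) A
    eps guards against division by zero for isolated nodes.
    """
    adj = np.asarray(adj, dtype=float)
    degree = adj.sum(axis=1)
    identity = np.eye(adj.shape[0])
    if random_walk:
        d_inv = 1.0 / (degree + eps)
        return identity - d_inv[:, None] * adj  # scale row i of A by 1/d_i
    d_inv_sqrt = 1.0 / (np.sqrt(degree) + eps)
    # scale row i by 1/sqrt(d_i) and column j by 1/sqrt(d_j)
    return identity - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]


if __name__ == "__main__":
    a = np.array([[0.0, 1.0, 1.0],
                  [1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0]])
    l_rw = normalized_laplacian(a, random_walk=True)
    l_sym = normalized_laplacian(a)
    print(np.allclose(l_rw.sum(axis=1), 0.0, atol=1e-5))  # True
    print(np.allclose(l_sym, l_sym.T))  # True
```

The two matrices are similar (`L_rw = D^(-1/2) L_sym D^(1/2)`), so they share eigenvalues, but `L_rw` is generally not symmetric, and each of its rows sums to zero.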
The configuration and performance of my reproduced model on the 10 datasets are as follows:
Table 1: Configuration and performance for all datasets
| Dataset | window_size | horizon | norm_method | MAE | RMSE | MAPE (%) |
|---|---|---|---|---|---|---|
| METR-LA | 12 | 3 | z_score |  |  |  |
| PEMS-BAY | 12 | 3 | z_score |  |  |  |
| PEMS03 | 12 | 3 | z_score |  |  |  |
| PEMS04 | 12 | 3 | z_score |  |  |  |
| PEMS07 | 12 | 3 | z_score |  |  |  |
| PEMS08 | 12 | 3 | z_score |  |  |  |
| COVID-19 | 28 | 28 | z_score |  |  |  |
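Table 1 reports MAE, RMSE, and MAPE. For reference, here is a minimal NumPy sketch of these metrics. It assumes they are computed on de-normalized predictions and that MAPE skips entries whose ground truth is zero (a common convention for traffic data); the repository's exact masking rule may differ, and the function names are illustrative.

```python
import numpy as np


def mae(y_true, y_pred):
    """Mean absolute error."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_pred - y_true)))


def rmse(y_true, y_pred):
    """Root mean squared error."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))


def mape(y_true, y_pred):
    """Mean absolute percentage error in percent.

    Assumption: entries whose ground truth is zero are skipped.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    mask = y_true != 0
    ratio = np.abs((y_pred[mask] - y_true[mask]) / y_true[mask])
    return float(np.mean(ratio) * 100.0)


if __name__ == "__main__":
    truth = [2.0, 4.0, 0.0]
    pred = [1.0, 5.0, 1.0]
    print(mae(truth, pred), rmse(truth, pred), mape(truth, pred))  # 1.0 1.0 37.5
```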