I noticed two potential errors in the CH02 notebook.
- The notebook merges the three sets of predictions like this:
zero_shot_preds = zero_shot_preds.rename(columns={"NBEATS": "NBEATS_pretrained"})
finetuned_preds = finetuned_preds.rename(columns={"NBEATS": "NBEATS_finetuned"})
trained_preds = trained_preds.rename(columns={"NBEATS": "NBEATS_trained"})
test_df = pd.merge(test_df, zero_shot_preds, 'left', 'ds')
test_df = pd.merge(test_df, finetuned_preds, 'left', 'ds')
test_df = pd.merge(test_df, trained_preds, 'left', 'ds')
test_df.head()
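Running this throws an error, because all three prediction DataFrames have a unique_id column and the merges join on ds only, so the unique_id columns collide. I'd merge on both keys and select only the prediction column, something like this: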
keys = ["unique_id", "ds"]
test_df = test_df.merge(zero_shot_preds[keys + ["NBEATS_pretrained"]], on=keys, how="left")
test_df = test_df.merge(finetuned_preds[keys + ["NBEATS_finetuned"]], on=keys, how="left")
test_df = test_df.merge(trained_preds[keys + ["NBEATS_trained"]], on=keys, how="left")
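To make the suggestion easy to try outside the notebook, here is a minimal, self-contained sketch on toy data. The series name `series_1`, the dates, the values, and the helper `toy_preds` are all made up for illustration; only the column names and the merge pattern come from the notebook.

```python
import pandas as pd

keys = ["unique_id", "ds"]
dates = pd.to_datetime(["2020-01-31", "2020-02-29", "2020-03-31"])

# Toy stand-in for the notebook's test_df (hypothetical values).
test_df = pd.DataFrame({"unique_id": "series_1", "ds": dates, "y": [1.0, 2.0, 3.0]})


def toy_preds(column, offset):
    """Build a fake prediction frame with the same keys as test_df."""
    return pd.DataFrame(
        {"unique_id": "series_1", "ds": dates, column: [v + offset for v in (1.0, 2.0, 3.0)]}
    )


prediction_frames = {
    "NBEATS_pretrained": toy_preds("NBEATS_pretrained", 0.1),
    "NBEATS_finetuned": toy_preds("NBEATS_finetuned", 0.2),
    "NBEATS_trained": toy_preds("NBEATS_trained", 0.3),
}

# Merge on both keys and pull in only the prediction column each time,
# so unique_id is never duplicated and no _x/_y suffixes appear.
merged = test_df
for column, preds in prediction_frames.items():
    merged = merged.merge(preds[keys + [column]], on=keys, how="left")

print(merged.head())
```

Because each merge brings in exactly one new column, the result keeps a single `unique_id` column next to `ds`, `y`, and the three prediction columns.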
- In the section "Forecasting another frequency", the notebook does:
pretrained_model = NeuralForecast.load(path='./model')
d_zero_shot_preds = pretrained_model.predict(d_input_df)
and the head of the resulting DataFrame shows that the predictions are on a monthly basis:
| ds | NBEATS |
|---|---|
| 1981-12-31 | 15.576023 |
| 1982-01-31 | 16.021452 |
| 1982-02-28 | 16.849257 |
| 1982-03-31 | 17.739635 |
| 1982-04-30 | 17.888920 |
But these monthly predictions are then combined with the daily predictions produced here:
models = [NBEATS(input_size=2*horizon, h=horizon, max_steps=500)]
nf = NeuralForecast(models=models, freq='D')
nf.fit(df=d_input_df)
d_trained_preds = nf.predict()
d_zero_shot_preds = d_zero_shot_preds.rename(columns={"NBEATS": "NBEATS_zero_shot"})
d_zero_shot_preds = d_zero_shot_preds.reset_index(drop=True)
d_trained_preds = d_trained_preds.rename(columns={"NBEATS": "NBEATS_trained"})
d_test_df = pd.merge(d_test_df, d_trained_preds, 'left', 'ds')
d_test_df = pd.concat([d_test_df, d_zero_shot_preds['NBEATS_zero_shot']], axis=1)
This shouldn't work: the zero-shot predictions are monthly while the test set and the trained predictions are daily, and the code throws an error when I try to run it.
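A quick way to catch this kind of mismatch before merging is to check the spacing of the timestamps in each prediction frame. This is only a sketch with pandas on toy data: the helper name `is_daily` and the series name `series_1` are my own, the monthly dates and values are copied from the table above, and the daily values are made up.

```python
import pandas as pd


def is_daily(df, id_col="unique_id", time_col="ds"):
    """Return True if consecutive timestamps within each series are exactly one day apart."""
    ordered = df.sort_values([id_col, time_col])
    steps = ordered.groupby(id_col)[time_col].diff().dropna()
    return bool((steps == pd.Timedelta(days=1)).all())


# Month-end timestamps, as in the head of d_zero_shot_preds shown above.
monthly_preds = pd.DataFrame(
    {
        "unique_id": "series_1",  # hypothetical series name
        "ds": pd.to_datetime(
            ["1981-12-31", "1982-01-31", "1982-02-28", "1982-03-31", "1982-04-30"]
        ),
        "NBEATS": [15.576023, 16.021452, 16.849257, 17.739635, 17.888920],
    }
)

# Daily timestamps, as expected from a model fitted with freq='D' (toy values).
daily_preds = pd.DataFrame(
    {
        "unique_id": "series_1",
        "ds": pd.date_range("1981-12-31", periods=5, freq="D"),
        "NBEATS": [15.0, 15.1, 15.2, 15.3, 15.4],
    }
)

print("monthly frame is daily:", is_daily(monthly_preds))
print("daily frame is daily:", is_daily(daily_preds))
```

Running a check like this on `d_zero_shot_preds` before the `pd.concat` would flag that its timestamps don't line up with the daily test set.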