-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathml_utils.py
More file actions
2293 lines (1949 loc) · 85.6 KB
/
Copy pathml_utils.py
File metadata and controls
2293 lines (1949 loc) · 85.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
def _require_keras():
try:
import importlib
tf = importlib.import_module("tensorflow")
# you can use tf.keras.* everywhere; no need to import keras separately
return tf
except Exception as e:
raise ImportError(
"TensorFlow/Keras required for this function. "
"Install with `pip install tensorflow` or `yourpkg[cnn]`."
) from e
import numpy as np
import xarray as xr
import pandas as pd
def time_series_split(
    data: "xr.Dataset",
    num_var,
    cat_var=None,
    mask="ocean_mask",
    split_ratio=(0.7, 0.2, 0.1),
    seed=42,
    X_mean=None,
    X_std=None,
    y_var="y",
    years=None,  # select one year, a list of years, or a slice
    cast_float32=True,
    contiguous_splits=False,
    return_full=False,
    nan_max_frac_y=0.5,  # max fraction of NaN ocean pixels per day in y
    nan_max_frac_v=0.05,  # max fraction of NaN ocean pixels per day in inputs
    add_missingness=False,
    verbose=False,
):
    """
    Split an xarray Dataset along time into train/val/test NumPy arrays.

    Pure-NumPy splitter/normalizer for an xarray Dataset (NumPy-backed).
    Processing order:
      1. Optionally subset the data to the requested year(s).
      2. Drop every day on which, for any checked variable, the number of NaN
         pixels inside the mask is not strictly below
         ``fraction * (number of masked-in pixels that day)``.
      3. Split the remaining time indices into train/val/test.
      4. Normalize numerical variables with the provided or training-set
         mean/std (a std of 0 is treated as 1).
      5. Impute non-finite masked-in values of numerical variables with the
         per-pixel training median (global training median where a pixel has
         no valid training value), then set masked-out pixels to 0.
         NaNs in categorical variables are replaced by 0.
    NaNs in the response are kept so a masked loss can ignore them.

    Parameters:
        data: xarray dataset with a 'time' dimension
        num_var: list of numerical variable names (to normalize)
        cat_var: list of categorical variable names (no normalization)
        mask: name of the mask in the data. 0 = ignore; 1 = use; can be static
            or one for each time step
        split_ratio: tuple (train, val, test), must sum to 1.0
        seed: random seed for the random split
        X_mean, X_std: optional mean/std arrays for num_var only
            (shape = [n_num_vars]); computed from the training split when
            either is None
        y_var: name of response variable in data
        years: year(s) to use: a single int/str, an iterable of years, or a
            slice passed to ``data.sel(time=...)``
        cast_float32: if True, cast outputs to float32 (good for TF)
        contiguous_splits: if True, train/val/test are consecutive blocks in
            the stored time order instead of a random permutation
        return_full: also return the full X and y
        nan_max_frac_y: threshold fraction of missing values for the response
        nan_max_frac_v: threshold fraction of missing values for inputs
        add_missingness: if True, append after every numerical channel a 0/1
            channel marking values that were imputed
        verbose: print the filtering summary and days kept per month

    Returns:
        X, y, X_train, y_train, X_val, y_val, X_test, y_test, X_mean, X_std
        X arrays are channels-last (time, <spatial dims of y>, channels).
        If return_full=False, X and y are None.

    Raises:
        ValueError: missing 'time' dimension, split_ratio not summing to 1,
            no input variables, or no timesteps left after filtering.
        KeyError: the mask or a requested variable is not in the dataset.
    """
    num_var = list(num_var)
    cat_var = [] if cat_var is None else list(cat_var)
    input_var = num_var + cat_var

    # ---------- checks ----------
    if "time" not in data.dims:
        raise ValueError("Dataset must contain a 'time' dimension.")
    if abs(sum(split_ratio) - 1.0) > 1e-6:
        raise ValueError("split_ratio must sum to 1.0")
    if mask not in data:
        raise KeyError(f"Dataset must contain mask variable '{mask}' (1=ocean, 0=land).")
    if not input_var:
        raise ValueError("No input variables provided.")
    check_vars = input_var + [y_var]
    for v in check_vars:
        if v not in data:
            raise KeyError(f"Variable '{v}' not found in dataset.")

    # ---------- subset by year(s) ----------
    if years is not None:
        if isinstance(years, (str, int, np.integer)):
            data = data.sel(time=str(years))
        elif isinstance(years, slice):
            data = data.sel(time=years)
        else:
            # assume iterable of years (ints/strs)
            ti = pd.DatetimeIndex(np.asarray(data["time"].values))
            wanted = sorted({int(y) for y in years})
            sel = xr.DataArray(
                np.isin(ti.year, wanted),
                coords={"time": data["time"]},
                dims=["time"],
            )
            data = data.sel(time=sel)
        if data.sizes.get("time", 0) == 0:
            raise ValueError("No timesteps left after year filtering.")

    # ---------- NaN-based time filtering (only where mask == 1) ----------
    # The response is the template every other field is broadcast against.
    template = data[y_var]
    ocean = data[mask].astype(bool)
    if "time" not in ocean.dims:
        ocean = ocean.expand_dims({"time": data["time"]})
    ocean = ocean.broadcast_like(template)
    spatial_dims = [d for d in ocean.dims if d != "time"]
    ocean_pix_per_t = ocean.sum(dim=spatial_dims)

    valid_times = xr.DataArray(
        np.ones(data.sizes["time"], dtype=bool),
        coords={"time": data["time"]},
        dims=["time"],
    )
    for v in check_vars:
        arr = data[v]
        if "time" not in arr.dims:
            arr = arr.expand_dims({"time": data["time"]})
        arr = arr.broadcast_like(template)
        frac = nan_max_frac_y if v == y_var else nan_max_frac_v
        nan_thresh = frac * ocean_pix_per_t  # (time,)
        v_nan = xr.apply_ufunc(np.isnan, arr) & ocean
        v_nan_count = v_nan.sum(dim=spatial_dims)
        # keep a day only while its NaN count stays strictly below the threshold
        valid_times = valid_times & (v_nan_count < nan_thresh)

    before = int(data.sizes["time"])
    data = data.sel(time=valid_times)
    ocean = ocean.sel(time=valid_times)
    after = int(data.sizes["time"])
    if after == 0:
        raise ValueError("No timesteps left after NaN filtering.")

    if verbose:
        yrs_msg = f" (years={years})" if years is not None else ""
        print(
            f"[NaN filter]{yrs_msg} kept {after}/{before} days "
            f"(< {nan_max_frac_y * 100:.1f}% NaNs over ocean in '{y_var}', "
            f"< {nan_max_frac_v * 100:.1f}% in each input variable)."
        )
        # --- days-per-month report ---
        t = pd.to_datetime(data["time"].values)
        # group by year-month (works for one or many years)
        per_month = (
            pd.Series(1, index=pd.Index(t, name="time"))
            .groupby([t.year, t.month])
            .sum()
            .astype(int)
        )
        # prettify as "YYYY-MM"
        per_month.index = [f"{y:04d}-{m:02d}" for y, m in per_month.index]
        print("Days kept per month:")
        for ym, cnt in per_month.items():
            print(f"  {ym}: {cnt}")
        # compact 12-month line when only a single year is present
        if t.year.nunique() == 1:
            counts = (
                pd.Series(1, index=t)
                .groupby(t.month)
                .sum()
                .reindex(range(1, 13), fill_value=0)
                .astype(int)
            )
            print("By month (Jan..Dec):", " ".join(f"{c:2d}" for c in counts.values))

    # ---------- split indices ----------
    time_len = data.sizes["time"]
    if contiguous_splits:
        # consecutive blocks in the stored time order
        all_indices = np.arange(time_len)
    else:
        rng = np.random.default_rng(seed)
        all_indices = rng.choice(time_len, size=time_len, replace=False)
    train_end = int(split_ratio[0] * time_len)
    val_end = int((split_ratio[0] + split_ratio[1]) * time_len)
    train_idx = np.sort(all_indices[:train_end])
    val_idx = np.sort(all_indices[train_end:val_end])
    test_idx = np.sort(all_indices[val_end:])

    # ---------- helpers ----------
    # Post-filter template; its dim order (time first) is the layout every
    # array is forced into, so inputs, mask and response always line up.
    tmpl = data[y_var].transpose("time", ...)
    target_dims = tmpl.dims

    def fetch(var):
        """Return `var` as a NumPy array laid out like the response."""
        arr = data[var]
        if "time" not in arr.dims:
            arr = arr.expand_dims({"time": data["time"]})
        arr = arr.broadcast_like(tmpl).transpose(*target_dims)
        values = arr.values
        return values.astype("float32", copy=False) if cast_float32 else values

    # ---------- normalization stats (training split, before imputation) ----------
    if num_var:
        if X_mean is None or X_std is None:
            means, stds = [], []
            for v in num_var:
                a_tr = fetch(v)[train_idx]
                means.append(np.nanmean(a_tr))
                stds.append(np.nanstd(a_tr))
            stat_dtype = "float32" if cast_float32 else a_tr.dtype
            X_mean = np.asarray(means, dtype=stat_dtype)
            X_std = np.asarray(stds, dtype=stat_dtype)
        # asarray so the zero-std guard is element-wise even for list inputs
        mean_arr = np.asarray(X_mean)
        std_arr = np.asarray(X_std)
        std_safe = np.where(std_arr == 0, 1.0, std_arr)
    else:
        X_mean = np.array([], dtype="float32" if cast_float32 else float)
        X_std = np.array([], dtype="float32" if cast_float32 else float)
        mean_arr = X_mean
        std_safe = X_std

    # ---------- per-pixel medians for num_var using ONLY training data ----------
    ocean_np = ocean.transpose(*target_dims).values
    medians = []
    for v in num_var:
        a_tr = fetch(v)[train_idx]
        oce_tr = ocean_np[train_idx]
        # mask out: land OR invalid values
        masked = np.ma.array(a_tr, mask=(~oce_tr) | (~np.isfinite(a_tr)))
        # per-pixel median across time; all-masked pixels become NaN
        med = np.ma.filled(np.ma.median(masked, axis=0), np.nan)
        # fallback for pixels with no finite ocean value in training
        cnt = np.isfinite(masked.filled(np.nan)).sum(axis=0)
        global_med = float(np.ma.median(masked)) if masked.count() > 0 else 0.0
        med = np.where(cnt > 0, med, global_med).astype("float32", copy=False)
        medians.append(med)

    def build_split(idx):
        """Assemble the channels-last input array for the time indices `idx`."""
        chans = []
        oce_t = ocean_np[idx]
        # numeric: impute with per-pixel median, normalize, zero out land
        for k, v in enumerate(num_var):
            a = fetch(v)[idx]
            miss = (~np.isfinite(a)) & oce_t
            a = np.where(miss, medians[k], a)
            a = (a - mean_arr[k]) / std_safe[k]
            a = np.where(oce_t, a, 0.0)  # set values over land to 0
            chans.append(a.astype("float32", copy=False))
            if add_missingness:
                # 0/1 channel indicating originally-missing inputs
                chans.append(miss.astype("float32", copy=False))
        # categorical: just fill NaNs with 0
        for v in cat_var:
            a = np.nan_to_num(fetch(v)[idx])
            chans.append(a.astype("float32", copy=False))
        # stack channels last
        return np.stack(chans, axis=-1)

    # IMPORTANT: keep NaNs in y so a masked loss can ignore cloudy/land pixels
    y_full = tmpl.values
    if cast_float32:
        y_full = y_full.astype("float32", copy=False)

    # ---------- build splits ----------
    X_train, y_train = build_split(train_idx), y_full[train_idx]
    X_val, y_val = build_split(val_idx), y_full[val_idx]
    X_test, y_test = build_split(test_idx), y_full[test_idx]

    if return_full:
        everything = slice(0, time_len)
        X, y = build_split(everything), y_full[everything]
    else:
        X = None
        y = None

    return X, y, X_train, y_train, X_val, y_val, X_test, y_test, X_mean, X_std
# Save and Load fitted model
import json, zipfile, tempfile
from pathlib import Path
def save_cnn_bundle(zip_path, model, X_mean, X_std, meta=None):
    """
    Write a model and its normalization statistics into one zip archive.

    The archive holds three members:
      - model.keras : the model, written by ``model.save``
      - stats.npz   : arrays ``X_mean`` and ``X_std``
      - meta.json   : ``meta`` serialized as JSON (``{}`` when meta is empty/None)

    Parent directories of ``zip_path`` are created when missing.
    Returns the archive path as a string.
    """
    # Called only for its side effect: raises ImportError when TensorFlow is absent.
    _require_keras()

    archive = Path(zip_path)
    archive.parent.mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        # arcname -> staged file, in the order the members are written to the zip
        staged = {
            name: scratch_dir / name
            for name in ("model.keras", "stats.npz", "meta.json")
        }

        model.save(staged["model.keras"])
        np.savez(staged["stats.npz"], X_mean=X_mean, X_std=X_std)
        staged["meta.json"].write_text(json.dumps(meta if meta else {}))

        with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED) as bundle:
            for arcname, staged_file in staged.items():
                bundle.write(staged_file, arcname=arcname)

    return str(archive)
def load_cnn_bundle(zip_path, compile=False):
    """
    Load a bundle produced by save_cnn_bundle().

    Parameters:
        zip_path: path to the zip archive
        compile: forwarded to ``tf.keras.models.load_model`` as ``compile``

    Returns: (model, X_mean, X_std, meta_dict)
        meta_dict is ``{}`` when the archive has no ``meta.json``
        (older bundles).

    Raises:
        ImportError: TensorFlow is not importable.
        KeyError: ``model.keras`` or ``stats.npz`` is missing from the archive
            (raised by ``zipfile``).
    """
    tf = _require_keras()
    zip_path = Path(zip_path)
    with zipfile.ZipFile(zip_path, "r") as z, tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        # The model and stats are loaded from real files, so extract them.
        z.extract("model.keras", path=tmp)
        z.extract("stats.npz", path=tmp)

        # meta.json might be missing in older bundles; handle gracefully.
        meta = {}
        if "meta.json" in z.namelist():
            meta = json.loads(z.read("meta.json"))

        model = tf.keras.models.load_model(tmp / "model.keras", compile=compile)

        # Close the npz handle before the temp directory is removed; an open
        # handle makes the cleanup fail on platforms that lock open files.
        # Indexing materializes the arrays, so they stay valid afterwards.
        with np.load(tmp / "stats.npz") as stats:
            X_mean, X_std = stats["X_mean"], stats["X_std"]
    return model, X_mean, X_std, meta
## PLOTTING
import numpy as np
import matplotlib.pyplot as plt
def predict_and_plot_date(
    data_xr,
    date,                 # "YYYY-MM-DD" or np.datetime64
    model,
    num_var,              # list of vars to normalize
    cat_var,              # list of vars not normalized (e.g., ocean_mask, sin/cos time)
    X_mean, X_std,        # per-channel stats for num_var (shape [len(num_var)])
    y_var="y",
    mask_var="ocean_mask",
    model_type="cnn",     # "cnn" or "tabular" ("brt" is accepted as an alias)
    cast_float32=True,
    use_percentiles=False, p_lo=5, p_hi=95,
    cmap="viridis"
):
    """
    Build one-sample input from dataset for a specific date, predict, and plot True vs Pred.
    Works with CNN (map→map) and tabular models (flattened pixels).

    Numerical variables are normalized with ``X_mean``/``X_std`` when both are
    given (a std of 0 is treated as 1); NaNs in every input channel are then
    replaced by 0. Pixels where ``mask_var == 0`` are set to NaN in both maps.

    Parameters:
        data_xr: dataset with a 'time' coordinate and 'lat'/'lon' coordinates
        date: date to plot; must match a value of the time coordinate exactly
        model: object with ``predict``; for "cnn" it is called with a
            (1, H, W, C) array and ``verbose=0``, otherwise with (H*W, C)
        num_var, cat_var: input variable names (normalized / not normalized)
        X_mean, X_std: per-variable stats for num_var, or None to skip scaling
        y_var: name of the response variable
        mask_var: name of the mask variable (0 = land)
        model_type: "cnn", or "tabular"/"brt" for flattened-pixel models
        cast_float32: cast fetched fields to float32
        use_percentiles, p_lo, p_hi: use percentiles of the finite values of
            both maps as colour limits instead of min/max
        cmap: matplotlib colormap

    Returns:
        (y_true, y_pred) as 2D arrays, land set to NaN, rows flipped when the
        lat coordinate is descending (i.e. exactly as plotted).

    Raises:
        ValueError: unknown model_type, no input variables, or date not found.
    """
    # ---- validate cheap arguments before touching the data
    if model_type not in ("cnn", "tabular", "brt"):
        raise ValueError("model_type must be 'cnn' or 'tabular'.")
    num_var = list(num_var)
    cat_var = [] if cat_var is None else list(cat_var)
    if not (num_var or cat_var):
        raise ValueError("No input variables provided.")

    # ---- resolve date index
    date64 = np.datetime64(str(date))
    times = np.asarray(data_xr["time"].values)
    idxs = np.where(times == date64)[0]
    if idxs.size == 0:
        raise ValueError(f"Date {date} not found in dataset time coord.")
    t = int(idxs[0])

    # ---- spatial template: first available among inputs or y.
    # It may be a static field, so only index time when it has that dimension.
    tmpl = data_xr[(num_var + cat_var + [y_var])[0]]
    like = tmpl.isel(time=t) if "time" in tmpl.dims else tmpl

    def fetch_2d(varname):
        """Return `varname` at time index t as an (H, W) array on the template grid."""
        arr = data_xr[varname]
        arr_t = arr.isel(time=t) if "time" in arr.dims else arr
        arr_t = arr_t.broadcast_like(like)  # ensure same H,W
        # rows must be latitude: the north-up flip below acts on axis 0
        if set(arr_t.dims) == {"lat", "lon"}:
            arr_t = arr_t.transpose("lat", "lon")
        a = arr_t.values
        if cast_float32:
            a = a.astype("float32", copy=False)
        return a

    # ---- build channels for this date
    normalize = (X_mean is not None) and (X_std is not None)
    chans = []
    for k, vn in enumerate(num_var):
        a = fetch_2d(vn)
        if normalize:
            a = (a - X_mean[k]) / (1.0 if X_std[k] == 0 else X_std[k])
        chans.append(np.nan_to_num(a))
    for vn in cat_var:
        chans.append(np.nan_to_num(fetch_2d(vn)))

    # stack to (H,W,C)
    X_map = np.stack(chans, axis=-1)
    H, W, C = X_map.shape

    # ---- ground truth map
    y_true = fetch_2d(y_var)

    # ---- predict
    if model_type == "cnn":
        _require_keras()  # ensure TF present only for cnn path
        y_pred = model.predict(X_map[np.newaxis, ...], verbose=0)[0]
        if y_pred.ndim == 3 and y_pred.shape[-1] == 1:
            y_pred = y_pred[..., 0]
    else:
        y_pred = model.predict(X_map.reshape(-1, C)).reshape(H, W)

    # ---- mask land to NaN (mask_var==0 → land)
    land = (fetch_2d(mask_var) == 0.0)
    y_true = np.where(land, np.nan, y_true)
    y_pred = np.where(land, np.nan, y_pred)

    # ---- color limits (None lets matplotlib choose when nothing is finite)
    vmin = vmax = None
    if use_percentiles:
        if np.isfinite(y_true).any() and np.isfinite(y_pred).any():
            stack = np.concatenate([y_true[~np.isnan(y_true)], y_pred[~np.isnan(y_pred)]])
            vmin, vmax = np.percentile(stack, p_lo), np.percentile(stack, p_hi)
    elif np.isfinite(y_true).any() or np.isfinite(y_pred).any():
        vmin = np.nanmin([y_true, y_pred])
        vmax = np.nanmax([y_true, y_pred])

    # ---- ensure North is up
    lat = np.array(data_xr.lat.values)
    if lat[0] > lat[-1]:  # lat is descending
        y_true = np.flipud(y_true)
        y_pred = np.flipud(y_pred)
    # extent must be (xmin, xmax, ymin, ymax) with increasing y
    lon_min, lon_max = float(data_xr.lon.min()), float(data_xr.lon.max())
    lat_min, lat_max = float(lat.min()), float(lat.max())
    extent = [lon_min, lon_max, lat_min, lat_max]

    # ---- plot
    fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)
    panels = (
        (y_true, f"True {y_var} — {np.datetime_as_string(date64)}"),
        (y_pred, f"Predicted ({model_type.upper()})"),
    )
    for ax, (field, title) in zip(axes, panels):
        im = ax.imshow(field, origin="lower", extent=extent, vmin=vmin, vmax=vmax, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.show()
    return y_true, y_pred
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
def plot_true_vs_predicted_year_multi(
    data, year,
    models,                 # list of models
    X_mean, X_std,          # per-channel stats for num_var only
    num_var, cat_var,       # lists of variable names
    y_var="y",
    model_types="cnn",      # one type for all models, or a list like ['cnn','tabular', ...]
    model_names=None,       # optional names for column titles
    cmap='viridis',
    day=1,
    use_percentiles=True, p_lo=5, p_hi=95,
    mask_var="ocean_mask",
):
    """
    Plot truth next to the predictions of several models, one row per month.

    For every (year, month) present in the selection, the first date whose
    day-of-month is >= ``day`` is used (or the last date of the month when none
    qualifies). Each row shows the true map followed by one map per model; the
    model titles carry R² and RMSE computed over pixels that are finite in both
    truth and prediction. Colour limits are shared within a row.

    Parameters:
        data: dataset with 'time', 'lat' and 'lon' coordinates
        year: selection passed to ``data.sel(time=...)``; an integer year is
            converted to its string form first
        models: list of models with a ``predict`` method
        X_mean, X_std: per-variable stats for num_var, or None to skip scaling
            (a std of 0 is treated as 1)
        num_var, cat_var: input variable names (normalized / not normalized)
        y_var: name of the response variable
        model_types: "cnn", "tabular" or "brt"; either a single string applied
            to every model or a list with one entry per model
        model_names: column titles; defaults to "Model 1", "Model 2", ...
        cmap: matplotlib colormap
        day: preferred day of month
        use_percentiles, p_lo, p_hi: percentile colour limits instead of min/max
        mask_var: name of the mask variable (0 = land, set to NaN in the maps)

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        ValueError: inconsistent list lengths, unknown model type, no input
            variables, or no timesteps in the selection.
    """
    # ---- validate arguments before touching the data
    models = list(models)
    if isinstance(model_types, str):
        # a bare string means "this type for every model"
        model_types = [model_types] * len(models)
    model_types = list(model_types)
    if len(models) != len(model_types):
        raise ValueError("models and model_types must have same length")
    if any(mt not in ("cnn", "tabular", "brt") for mt in model_types):
        raise ValueError("model_type must be 'cnn' or 'tabular'.")
    if model_names is None:
        model_names = [f"Model {i+1}" for i in range(len(models))]
    elif len(model_names) != len(models):
        raise ValueError("models and model_names must have same length")
    num_var = list(num_var)
    cat_var = [] if cat_var is None else list(cat_var)
    if not (num_var or cat_var):
        raise ValueError("No input variables provided.")

    # integer years are not valid labels on a datetime index; strings are
    ds = data.sel(time=str(year) if isinstance(year, (int, np.integer)) else year)

    # ---- pick one date per month
    dates = pd.to_datetime(ds.time.values)
    picks = []
    for _, grp in pd.Series(dates).groupby([dates.year, dates.month]):
        on_or_after = grp[grp.dt.day >= day]
        picks.append(on_or_after.min() if len(on_or_after) else grp.max())
    if not picks:
        raise ValueError(f"No timesteps found for year selection {year!r}.")
    monthly_dates = pd.DatetimeIndex(sorted(picks))
    n_months = len(monthly_dates)

    lat = ds.lat.values
    lon = ds.lon.values
    flip_lat = lat[0] > lat[-1]
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]

    def fetch_2d(var, date):
        """Return `var` at `date` as a float32 (H, W) array on the grid of y."""
        arr = ds[var]
        arr_t = arr.sel(time=date) if "time" in arr.dims else arr
        arr_t = arr_t.broadcast_like(ds[y_var].sel(time=date))
        # rows must be latitude: the north-up flip acts on axis 0
        if set(arr_t.dims) == {"lat", "lon"}:
            arr_t = arr_t.transpose("lat", "lon")
        return arr_t.values.astype("float32", copy=False)

    # figure: True + one column per model; squeeze=False keeps axs 2D always
    ncols = 1 + len(models)
    fig, axs = plt.subplots(
        n_months, ncols,
        figsize=(3.2*ncols, 2.2*n_months),
        constrained_layout=True,
        squeeze=False,
    )

    normalize = (X_mean is not None) and (X_std is not None)
    for i, date in enumerate(monthly_dates):
        # Build (H,W,C) input for this date
        chans = []
        for k, v in enumerate(num_var):
            a = fetch_2d(v, date)
            if normalize:
                denom = 1.0 if X_std[k] == 0 else X_std[k]
                a = (a - X_mean[k]) / denom
            chans.append(np.nan_to_num(a))
        for v in cat_var:
            chans.append(np.nan_to_num(fetch_2d(v, date)))
        X_map = np.stack(chans, axis=-1)
        H, W, C = X_map.shape

        truth = fetch_2d(y_var, date)
        # fetched per date so a time-varying mask is handled correctly
        land_mask = (fetch_2d(mask_var, date) == 0.0)

        # Predict with each model
        preds = []
        for mdl, mtype in zip(models, model_types):
            if mtype == "cnn":
                _require_keras()  # ensure TF present only for cnn path
                yhat = mdl.predict(X_map[np.newaxis, ...], verbose=0)[0]
                if yhat.ndim == 3 and yhat.shape[-1] == 1:
                    yhat = yhat[..., 0]
            else:
                yhat = mdl.predict(X_map.reshape(-1, C)).reshape(H, W)
            preds.append(yhat)

        # Apply ocean mask and optional north-up flip
        truth_m = np.where(land_mask, np.nan, truth)
        preds_m = [np.where(land_mask, np.nan, p) for p in preds]
        if flip_lat:
            truth_m = np.flipud(truth_m)
            preds_m = [np.flipud(p) for p in preds_m]

        # Shared color limits per row (None when nothing is finite)
        all_maps = [truth_m] + preds_m
        finite_parts = [m[np.isfinite(m)] for m in all_maps if np.isfinite(m).any()]
        vmin = vmax = None
        if finite_parts:
            if use_percentiles:
                stack = np.concatenate(finite_parts)
                vmin, vmax = np.percentile(stack, p_lo), np.percentile(stack, p_hi)
            else:
                vmin = np.nanmin(all_maps)
                vmax = np.nanmax(all_maps)

        # True panel
        ax = axs[i, 0]
        ax.imshow(truth_m, origin='lower', extent=extent, vmin=vmin, vmax=vmax, cmap=cmap, aspect='equal')
        ax.set_title(f"{date.strftime('%Y-%m-%d')} — True", fontsize=9)
        ax.axis('off')

        # Prediction panels with metrics
        for j, (pmap, name) in enumerate(zip(preds_m, model_names), start=1):
            axp = axs[i, j]
            axp.imshow(pmap, origin='lower', extent=extent, vmin=vmin, vmax=vmax, cmap=cmap, aspect='equal')
            axp.axis('off')
            # metrics over pixels finite in both maps
            m = np.isfinite(truth_m) & np.isfinite(pmap)
            if m.any():
                r2 = r2_score(truth_m[m].ravel(), pmap[m].ravel())
                rmse = np.sqrt(np.mean((truth_m[m] - pmap[m])**2))
                axp.set_title(f"{name}\n$R^2$={r2:.2f}, RMSE={rmse:.2f}", fontsize=9)
            else:
                axp.set_title(f"{name}\nno valid pixels", fontsize=9)

    # NOTE: no colorbar is drawn; colour limits differ from row to row, so a
    # single shared colorbar would not describe every panel.
    plt.show()
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error
from skimage.metrics import structural_similarity as ssim
import calendar
def plot_metric_by_month(
    data, years, model, X_mean, X_std, num_var, cat_var,
    training_year=None, metric='r2',
    y_name='y', mask_var='ocean_mask',
    ssim_win_size=None, ssim_sigma=None,
    ymin=None, ymax=None,
    model_type="cnn",
):
    """
    Plot a single monthly performance metric for one model across multiple years.

    Supports CNN models ((1,H,W,C) → (H,W) predictions) and non-CNN
    "tabular/BRT"-style models ((H*W,C) → (H*W) predictions). If
    `model_type == "cnn"`, TensorFlow/Keras is required and is imported lazily
    via `_require_keras()`.

    For each `year`:
      1) Select one representative day per month (the first day present in
         that month).
      2) Build an (H, W, C) feature map by stacking:
           - numerical vars in `num_var`, normalized with (`X_mean`, `X_std`)
             if provided (0 std → no scaling),
           - categorical/aux vars in `cat_var` (no normalization).
         NaNs in every channel are replaced by 0.
      3) Predict for that day:
           - CNN: `model.predict(X_map[None, ...]) → (1,H,W[,1])`.
           - BRT/Tabular: reshape (H,W,C) → (H*W,C), predict, reshape back.
      4) Mask land where `mask_var == 0.0` and compute the requested metric:
           - 'r2'  : coefficient of determination over valid pixels
           - 'rmse': root mean squared error
           - 'mae' : mean absolute error
           - 'bias': mean(pred - truth)
           - 'ssim': structural similarity (with optional `ssim_win_size` and
                     Gaussian weighting via `ssim_sigma`). NaNs are filled with
                     per-image means only for the SSIM computation.
         For all metrics except SSIM the score is NaN when no pixel is valid.
      5) Plot the monthly metric values, dashed for `training_year`.

    Parameters
    ----------
    data : xr.Dataset
        Contains `y_name`, `mask_var`, and all variables in `num_var`/`cat_var`,
        with a `time` dimension and (lat, lon) grid.
    years : sequence[int or str]
        Years to evaluate (e.g., [2019, 2020]). Integer years are converted to
        strings before being passed to ``data.sel(time=...)``.
    model : object
        - If `model_type == "cnn"`: a tf.keras.Model returning (B,H,W[,1]).
        - If `model_type in {"brt","tabular"}`: an estimator with
          `.predict(X_2d)` producing length n_samples predictions.
    X_mean, X_std : array-like or None
        Per-channel stats for numerical variables in `num_var`.
        If either is None, numerical inputs are not normalized.
    num_var, cat_var : list[str]
        Names of numerical and categorical/aux variables to stack as channels.
    training_year : int, str or None
        If provided, that year's line is dashed and labeled "(train)". It is
        matched against `years` by string form, so 2020 matches "2020".
    metric : {'r2','rmse','mae','bias','ssim'}, default 'r2'
        Which metric to plot per month.
    y_name : str, default 'y'
        Target variable in `data`.
    mask_var : str, default 'ocean_mask'
        Land/ocean mask; pixels with 0.0 are treated as land and masked.
    ssim_win_size : int or None
        SSIM window size (must be odd if provided).
    ssim_sigma : float or None
        If provided, enables Gaussian-weighted SSIM with this sigma.
    ymin, ymax : float or None
        y-limits for the plot. If both are None, Matplotlib defaults are used.
    model_type : {'cnn','brt','tabular'}, default 'cnn'
        Selects the prediction pathway.

    Returns
    -------
    None
        Displays a Matplotlib figure: month (1–12) on x-axis, the chosen metric
        on y-axis, with one line per year.

    Raises
    ------
    ValueError
        Unknown `metric` or `model_type`.

    Notes
    -----
    - Only one representative day per month is used (first available day).
    - For SSIM, `data_range` is derived from the truth field when possible.
    """
    # ---- validate arguments up front (no assert: it is stripped under -O)
    valid_metrics = ('r2', 'rmse', 'mae', 'bias', 'ssim')
    if metric not in valid_metrics:
        raise ValueError(f"metric must be one of {valid_metrics}, got {metric!r}.")
    if model_type not in ("cnn", "brt", "tabular"):
        raise ValueError("model_type must be 'cnn' or 'brt'/'tabular'.")
    if model_type == "cnn":
        _require_keras()  # only require TF/Keras for CNN

    def fetch_2d(ds, var, date, like_var):
        """Return `var` at `date` as a float32 (lat, lon) array on the grid of `like_var`."""
        arr = ds[var]
        arr_t = arr.sel(time=date) if 'time' in arr.dims else arr
        arr_t = arr_t.broadcast_like(ds[like_var].sel(time=date))
        # ensure spatial order is (lat, lon)
        if tuple(d for d in arr_t.dims if d != 'time') != ('lat', 'lon'):
            arr_t = arr_t.transpose(..., 'lat', 'lon') if 'time' in arr_t.dims else arr_t.transpose('lat', 'lon')
        return arr_t.values.astype('float32', copy=False)

    def score_maps(truth, pred):
        """Compute the requested metric for one pair of land-masked maps."""
        if metric == 'ssim':
            # fill NaNs with the image mean, only for the SSIM computation
            t = np.nan_to_num(truth, nan=(np.nanmean(truth) if np.isfinite(truth).any() else 0.0))
            p = np.nan_to_num(pred, nan=(np.nanmean(pred) if np.isfinite(pred).any() else 0.0))
            # robust data_range: prefer the truth range, fall back to the
            # filled field, and finally to 1.0 for a constant field
            dr = np.nanmax(truth) - np.nanmin(truth)
            if not np.isfinite(dr) or dr == 0:
                dr = (np.nanmax(t) - np.nanmin(t)) or 1.0
            # build kwargs safely (don't pass sigma=None)
            ssim_kwargs = {"data_range": dr}
            if ssim_win_size is not None:
                ssim_kwargs["win_size"] = int(ssim_win_size)  # must be odd
            if ssim_sigma is not None:
                ssim_kwargs["gaussian_weights"] = True
                ssim_kwargs["sigma"] = float(ssim_sigma)
            return ssim(t.astype(np.float64), p.astype(np.float64), **ssim_kwargs)

        valid = ~np.isnan(truth) & ~np.isnan(pred)
        if not valid.any():
            return np.nan
        if metric == 'r2':
            return r2_score(truth[valid].ravel(), pred[valid].ravel())
        if metric == 'rmse':
            return float(np.sqrt(np.mean((truth[valid] - pred[valid])**2)))
        if metric == 'mae':
            return float(mean_absolute_error(truth[valid], pred[valid]))
        return float(np.mean(pred[valid] - truth[valid]))  # 'bias'

    normalize = X_mean is not None and X_std is not None
    metric_by_year_month = {}
    for year in years:
        # integer years are not valid labels on a datetime index; strings are
        ds = data.sel(time=str(year) if isinstance(year, (int, np.integer)) else year)
        dates = pd.to_datetime(ds.time.values)
        monthly_dates = (
            pd.Series(dates).groupby([dates.year, dates.month]).min().sort_values()
        )
        scores = []
        for date in monthly_dates:
            # build (H,W,C) input for this date
            chans = []
            for k, v in enumerate(num_var):
                a = fetch_2d(ds, v, date, y_name)
                if normalize:
                    denom = 1.0 if X_std[k] == 0 else X_std[k]
                    a = (a - X_mean[k]) / denom
                chans.append(np.nan_to_num(a))
            for v in cat_var:
                chans.append(np.nan_to_num(fetch_2d(ds, v, date, y_name)))
            X_map = np.stack(chans, axis=-1)

            # predict
            if model_type == "cnn":
                pred = model.predict(X_map[np.newaxis, ...], verbose=0)[0]
                if pred.ndim == 3 and pred.shape[-1] == 1:
                    pred = pred[..., 0]
            else:
                H, W, C = X_map.shape
                pred = model.predict(X_map.reshape(-1, C)).reshape(H, W)

            # truth & mask
            truth = fetch_2d(ds, y_name, date, y_name)
            land = (fetch_2d(ds, mask_var, date, y_name) == 0.0)
            pred = np.where(land, np.nan, pred)
            truth = np.where(land, np.nan, truth)

            scores.append(score_maps(truth, pred))
        metric_by_year_month[year] = (monthly_dates.dt.month.values, scores)

    # plot
    plt.figure(figsize=(10, 5))
    for year, (months, scores) in metric_by_year_month.items():
        # compare by string form so 2020 and "2020" identify the same year
        is_train = training_year is not None and str(year) == str(training_year)
        label = f"{year} (train)" if is_train else str(year)
        style = "--" if is_train else "-"
        plt.plot(months, scores, style, marker='o', label=label)
    plt.xlabel("Month")
    plt.ylabel({'r2': "$R^2$", 'rmse': "RMSE", 'mae': "MAE", 'bias': "Bias", 'ssim': "SSIM"}[metric])
    plt.title(f"Monthly {metric.upper()} by Year")
    plt.xticks(np.arange(1, 13), calendar.month_abbr[1:13])
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    if ymin is not None or ymax is not None:
        plt.ylim(ymin, ymax)
    plt.show()
def plot_4metric_by_month(
data, years, model, X_mean, X_std, num_var, cat_var,
training_year=None,
y_name='y', mask_var='ocean_mask',
ssim_win_size=None, ssim_sigma=None,
ymin=None, ymax=None,
model_type="cnn",
):
"""
Compute and plot monthly performance metrics (R², bias, MAE, SSIM) over
selected days in each month (1, 7, 14, 28) for either a CNN or a BRT model.
For each year in ``years``:
1. Select timesteps from ``data`` within that year.
2. Group timesteps by (year, month).
3. Within each month, select dates whose day-of-month is in {1, 7, 14, 28}.
4. For each selected date, build an (H, W, C) feature map using numerical
variables (``num_var``) and categorical variables (``cat_var``). If
``X_mean``/``X_std`` are provided, numerical variables are normalized.
5. Run the model to obtain predictions for each (H, W) field:
- If ``model_type == "cnn"``:
* Stack daily inputs into (B, H, W, C) and call
``model.predict(X_batch)``.
* Output is expected as (B, H, W) or (B, H, W, 1).
- If ``model_type == "brt"``:
* For each day, reshape (H, W, C) → (H*W, C),
call ``model.predict(X_flat)``, then reshape predictions
back to (H, W).
6. For each day, mask land using ``mask_var``, drop NaNs and compute:
* R²
* Bias (pred - truth)
* MAE
* SSIM (with optional window size and Gaussian sigma)
7. Average daily metrics within each month (via ``np.nanmean``) to obtain
a monthly value.
8. Produce a 2×2 panel plot of monthly metrics across all years, with an
optional highlight for ``training_year``.
Months that do not contain any of the target days {1, 7, 14, 28} are skipped.
Parameters
----------
data : xarray.Dataset
Dataset containing the target and predictor variables.
Must have a ``time`` dimension and at least the variables in
``num_var``, ``cat_var``, ``y_name`` and ``mask_var``.
years : sequence of int
Years to evaluate (used with ``data.sel(time=year)``).
model : object
- If ``model_type == "cnn"``: a tf.keras.Model (or compatible) that
accepts (B, H, W, C) and returns (B, H, W) or (B, H, W, 1).
- If ``model_type == "brt"``: a scikit-learn-like regressor with
``predict(X_2d)`` where X_2d has shape (n_samples, n_features).
X_mean : array-like of float or None
Per-channel means for numerical variables in ``num_var``.
Length must match ``len(num_var)`` if not None. If None, no mean/std
normalization is applied.
X_std : array-like of float or None
Per-channel standard deviations for numerical variables in ``num_var``.
Length must match ``len(num_var)`` if not None. If a std is zero, that
channel is left unscaled. If None, no mean/std normalization is applied.
num_var : list of str
Names of numerical predictor variables in ``data``. Each is fetched,
broadcast to the target grid, optionally normalized, and stacked as a
channel.
cat_var : list of str
Names of categorical / non-normalized predictor variables in ``data``.
Each is fetched, broadcast to the target grid, and stacked as a channel
(no mean/std normalization).
training_year : int, optional
Year used for training. If provided, that year's line is dashed and
labeled "(train)".
y_name : str, default "y"
Name of the target variable in ``data``.
mask_var : str, default "ocean_mask"
Name of the land/ocean mask variable. Values equal to 0.0 are treated
as land and masked out.
ssim_win_size : int, optional
Window size for SSIM. Must be odd if provided.
ssim_sigma : float, optional
Standard deviation for Gaussian-weighted SSIM. If provided, Gaussian
weights are used.
ymin, ymax : float, optional
Common y-limits for all metric subplots. If None, matplotlib defaults.
model_type : {"cnn", "brt"}, default "cnn"
Type of model:
- "cnn": use batched (B, H, W, C) predictions.
- "brt": predict on flattened features per day and reshape back.
Notes
-----
- SSIM is computed on fields where land is masked and NaNs are filled with
the mean of valid values for that day (or 0 if none).
- Daily metric values within a month are averaged via ``np.nanmean``.
- The resulting figure shows four panels (R², Bias, MAE, SSIM) with month
on the x-axis (1–12) and one line per year.
"""
# helper
def fetch_2d(ds, var, date, like_var):
    """Return ``ds[var]`` at ``date`` as a float32 array ordered (lat, lon).

    Parameters
    ----------
    ds : mapping of variable name -> labelled array
        Presumably an xarray Dataset; each entry must expose ``dims``,
        ``sel``, ``broadcast_like``, ``transpose`` and ``values``.
    var : str
        Name of the variable to extract.
    date : time label
        Passed to ``sel(time=...)``.
    like_var : str
        Name of the variable whose slice at ``date`` serves as the
        broadcast template; it is always selected on ``time``.

    Returns
    -------
    numpy.ndarray
        The values cast to float32 (``copy=False``, so no copy is made
        when the data is already float32).
    """
    field = ds[var]
    # Variables without a time dimension are used as-is.
    if 'time' in field.dims:
        field = field.sel(time=date)

    template = ds[like_var].sel(time=date)
    field = field.broadcast_like(template)

    # Make sure the spatial axes come out as (lat, lon).
    spatial_dims = tuple(name for name in field.dims if name != 'time')
    if spatial_dims != ('lat', 'lon'):
        if 'time' in field.dims:
            # Keep any leading dims, put lat/lon last.
            field = field.transpose(..., 'lat', 'lon')
        else:
            field = field.transpose('lat', 'lon')

    return field.values.astype('float32', copy=False)
# target days within each month
target_days = {1, 7, 14, 28}
metrics = ['r2', 'bias', 'mae', 'ssim']
metric_by_year_month = {m: {} for m in metrics}
for year in years:
ds = data.sel(time=year)
dates = pd.to_datetime(ds.time.values)
gb = pd.Series(dates).groupby([dates.year, dates.month])
scores_dict = {m: [] for m in metrics}
months_list = []
for (yy, mm), group in gb:
month_dates = list(group)
pick = [d for d in month_dates if d.day in target_days]
if len(pick) == 0:
continue
chans_list = []
truths = []
lands = []
pred_list = []
for date in pick:
chans = []
for k, v in enumerate(num_var):
a = fetch_2d(ds, v, date, y_name)
if X_mean is not None and X_std is not None:
denom = 1.0 if X_std[k] == 0 else X_std[k]
a = (a - X_mean[k]) / denom
chans.append(np.nan_to_num(a))
for v in cat_var:
a = fetch_2d(ds, v, date, y_name)
chans.append(np.nan_to_num(a))
X_map = np.stack(chans, axis=-1) # (H, W, C)
truths.append(fetch_2d(ds, y_name, date, y_name))
lands.append(fetch_2d(ds, mask_var, date, y_name) == 0.0)
if model_type == "cnn":
chans_list.append(X_map)
elif model_type == "brt":
H, W, C = X_map.shape
X_flat = X_map.reshape(-1, C)
pred_flat = model.predict(X_flat)
pred_list.append(pred_flat.reshape(H, W))
else:
raise ValueError(f"Unknown model_type: {model_type!r}")
# ---- predict batch
if model_type == "cnn":
_ = _require_keras() # ensure TF present only for cnn path
X_batch = np.stack(chans_list, axis=0) # (B, H, W, C)
pred_batch = model.predict(X_batch, verbose=0)
if pred_batch.ndim == 4 and pred_batch.shape[-1] == 1:
pred_batch = pred_batch[..., 0] # (B, H, W)
else: # brt
pred_batch = np.stack(pred_list, axis=0) # (B, H, W)
# ---- metrics
r2_vals, bias_vals, mae_vals, ssim_vals = [], [], [], []
for b, date in enumerate(pick):
truth = np.where(lands[b], np.nan, truths[b])
pred = np.where(lands[b], np.nan, pred_batch[b])
m = ~np.isnan(truth) & ~np.isnan(pred)
if m.any():
r2_vals.append(r2_score(truth[m].ravel(), pred[m].ravel()))
mae_vals.append(float(mean_absolute_error(truth[m], pred[m])))
bias_vals.append(float(np.mean(pred[m] - truth[m])))
else:
r2_vals.append(np.nan)
mae_vals.append(np.nan)
bias_vals.append(np.nan)
t = np.nan_to_num(
truth,
nan=(np.nanmean(truth) if np.isfinite(truth).any() else 0.0),
)
p = np.nan_to_num(
pred,
nan=(np.nanmean(pred) if np.isfinite(pred).any() else 0.0),
)
dr = np.nanmax(truth) - np.nanmin(truth)
if not np.isfinite(dr) or dr == 0:
dr = (np.nanmax(t) - np.nanmin(t)) or 1.0
ssim_kwargs = {"data_range": dr}
if ssim_win_size is not None:
ssim_kwargs["win_size"] = int(ssim_win_size)
if ssim_sigma is not None:
ssim_kwargs["gaussian_weights"] = True
ssim_kwargs["sigma"] = float(ssim_sigma)
ssim_vals.append(
ssim(t.astype(np.float64), p.astype(np.float64), **ssim_kwargs)
)