From 1a42126609637a4bc25aff0e2ddeaf9f8be02199 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Thu, 14 Sep 2023 09:56:06 +0100 Subject: [PATCH 01/21] update from sueda --- stdata/vis/spacetime.py | 48 ++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index edcff6a..5cadf0e 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -1,13 +1,11 @@ import numpy as np -import matplotlib import matplotlib.pyplot as plt from matplotlib.widgets import Slider from mpl_toolkits.axes_grid1 import make_axes_locatable from datetime import datetime -from matplotlib.collections import PatchCollection from matplotlib.patches import Polygon import geopandas -import shapely +import shapely.geometry as sg def plot_polygon_collection( @@ -20,28 +18,24 @@ def plot_polygon_collection( edgecolor=None, alpha=1.0, linewidth=1.0, - **kwargs + **kwargs, ): - """ Plot a collection of Polygon geometries """ + """Plot a collection of Polygon geometries""" patches = [] for poly in geoms: - #a = np.asarray(poly.geoms[0].exterior) a = np.asarray(poly.exterior) - if poly.has_z: - poly = shapely.geometry.Polygon(zip(*poly.geoms[0].exterior.xy)) - patches.append(Polygon(a)) - patches = PatchCollection( + patches = plt.PatchCollection( patches, facecolor=facecolor, linewidth=linewidth, edgecolor=edgecolor, alpha=alpha, norm=norm, - **kwargs + **kwargs, ) if values is not None: @@ -99,7 +93,6 @@ def get_data(self, epoch): if x_train is None: return None, None - print(z_train.shape) z_train = np.array(z_train) s = np.c_[x_train, y_train] @@ -116,7 +109,7 @@ def setup(self): if self.norm_on_training: df = self.train_df - self.norm = matplotlib.colors.Normalize( + self.norm = plt.Normalize( vmin=np.min(df[self.columns[self.col]]), vmax=np.max(df[self.columns[self.col]]), ) @@ -131,7 +124,6 @@ def setup(self): def update(self, epoch): if self.geopandas_flag: - # If grid_plot is init with zero patches then we need to create them if self.grid_plot is None: self.plot(epoch) @@ -242,12 +234,9 @@ def __init__(self, columns, fig, ax, train_df, test_df, test_start, grid_plot_fl def setup(self): pass - def get_time_series(self, _id, data): + def get_time_series(self, _id, data): d = self.train_df[self.train_df[self.columns["id"]] == _id] - print(f'Plotting timeseries: {_id}') - - d = d.sort_values(by=self.columns["epoch"]) epochs = d[self.columns["epoch"]].astype(np.float32) @@ -260,7 +249,9 @@ def get_time_series(self, _id, data): def plot(self, _id): epochs, var, pred, observed = self.get_time_series(_id, self.train_df) - self.var_plot = self.ax.fill_between(epochs, pred - 1.96*np.sqrt(var), pred + 1.96*np.sqrt(var)) + self.var_plot = self.ax.fill_between( + epochs, pred - 1.96 * np.sqrt(var), pred + 1.96 * np.sqrt(var) + ) self.observed_scatter = self.ax.scatter(epochs, observed) self.pred_plot = self.ax.plot(epochs, pred) self.ax.set_xlim([self.min_test_epoch, self.max_test_epoch]) @@ -298,9 +289,7 @@ def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_ if grid_plot_flag: self.norm = grid_plot.norm else: - print("min: ", np.min(self.train_df[self.columns["pred"]])) - print("max: ", np.max(self.train_df[self.columns["pred"]])) - self.norm = matplotlib.colors.Normalize( + self.norm = plt.Normalize( vmin=np.min(self.train_df[self.columns["pred"]]), vmax=1000 ) @@ -317,7 +306,6 @@ def get_closest_observed(self, p): ) dists = np.sum((d - p) ** 2, axis=1) i = np.argmin(dists) - #if dists[i] <= 1e-4: if dists[i] <= 0.02: return self.train_df.iloc[i][self.columns["id"]] else: @@ -368,7 +356,9 @@ def update_active(self, _id): class SpaceTimeVisualise(object): - def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_start=None): + def __init__( + self, train_df, test_df, sat_df=None, geopandas_flag=True, test_start=None + ): columns = { "id": "id", "epoch": "epoch", @@ -388,7 +378,7 @@ def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_sta self.grid_plot_flag = not (self.test_df is None) - self.min_time = np.min(self.train_df[columns["epoch"]]) + s self.min_time = np.min(self.train_df[columns["epoch"]]) self.max_time = np.max(self.train_df[columns["epoch"]]) if test_start: @@ -418,7 +408,7 @@ def update_epoch(self, epoch): def show(self): self.fig = plt.figure(figsize=(12, 6)) - self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) + self.gs = plt.GridSpec(12, 4, wspace=0.25, hspace=0.25) self.grid_plot_1_ax = self.fig.add_subplot( self.gs[0:7, 0:2] ) # first row, first col @@ -492,7 +482,6 @@ def show(self): ) self.time_series_plot.setup() - if self.grid_plot_flag: self.val_grid_plot.plot(self.start_epoch) self.var_grid_plot.plot(self.start_epoch) @@ -505,8 +494,9 @@ def show(self): self.val_scatter_plot.plot_active(self.start_id) if self.sat_df is not None: - self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) - + self.time_series_plot.ax.scatter( + self.sat_df["epoch"], self.sat_df[self.columns["observed"]], alpha=0.4 + ) plt.show() From ef922fa4235f614e8ed0d2b90377a5d39df8a828 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Thu, 14 Sep 2023 10:03:10 +0100 Subject: [PATCH 02/21] revert --- stdata/vis/spacetime.py | 51 +++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 5cadf0e..bf522d1 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -1,11 +1,13 @@ import numpy as np +import matplotlib import matplotlib.pyplot as plt from matplotlib.widgets import Slider from mpl_toolkits.axes_grid1 import make_axes_locatable from datetime import datetime +from matplotlib.collections import PatchCollection from matplotlib.patches import Polygon import geopandas -import shapely.geometry as sg +import shapely def plot_polygon_collection( @@ -18,24 +20,29 @@ def plot_polygon_collection( edgecolor=None, alpha=1.0, linewidth=1.0, - **kwargs, + **kwargs ): - """Plot a collection of Polygon geometries""" + """ Plot a collection of Polygon geometries """ patches = [] for poly in geoms: - a = np.asarray(poly.exterior) + #a = np.asarray(poly.geoms[0].exterior) + #a = np.asarray(poly.exterior) + a = np.asarray(poly.exterior.xy).T + + if poly.has_z: + poly = shapely.geometry.Polygon(zip(*poly.geoms[0].exterior.xy)) patches.append(Polygon(a)) - patches = plt.PatchCollection( + patches = PatchCollection( patches, facecolor=facecolor, linewidth=linewidth, edgecolor=edgecolor, alpha=alpha, norm=norm, - **kwargs, + **kwargs ) if values is not None: @@ -93,6 +100,7 @@ def get_data(self, epoch): if x_train is None: return None, None + print(z_train.shape) z_train = np.array(z_train) s = np.c_[x_train, y_train] @@ -109,7 +117,7 @@ def setup(self): if self.norm_on_training: df = self.train_df - self.norm = plt.Normalize( + self.norm = matplotlib.colors.Normalize( vmin=np.min(df[self.columns[self.col]]), vmax=np.max(df[self.columns[self.col]]), ) @@ -124,6 +132,7 @@ def setup(self): def update(self, epoch): if self.geopandas_flag: + # If grid_plot is init with zero patches then we need to create them if self.grid_plot is None: self.plot(epoch) @@ -234,9 +243,12 @@ def __init__(self, columns, fig, ax, train_df, test_df, test_start, grid_plot_fl def setup(self): pass - def get_time_series(self, _id, data): + def get_time_series(self, _id, data): d = self.train_df[self.train_df[self.columns["id"]] == _id] + print(f'Plotting timeseries: {_id}') + + d = d.sort_values(by=self.columns["epoch"]) epochs = d[self.columns["epoch"]].astype(np.float32) @@ -249,9 +261,7 @@ def get_time_series(self, _id, data): def plot(self, _id): epochs, var, pred, observed = self.get_time_series(_id, self.train_df) - self.var_plot = self.ax.fill_between( - epochs, pred - 1.96 * np.sqrt(var), pred + 1.96 * np.sqrt(var) - ) + self.var_plot = self.ax.fill_between(epochs, pred - 1.96*np.sqrt(var), pred + 1.96*np.sqrt(var)) self.observed_scatter = self.ax.scatter(epochs, observed) self.pred_plot = self.ax.plot(epochs, pred) self.ax.set_xlim([self.min_test_epoch, self.max_test_epoch]) @@ -289,7 +299,9 @@ def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_ if grid_plot_flag: self.norm = grid_plot.norm else: - self.norm = plt.Normalize( + print("min: ", np.min(self.train_df[self.columns["pred"]])) + print("max: ", np.max(self.train_df[self.columns["pred"]])) + self.norm = matplotlib.colors.Normalize( vmin=np.min(self.train_df[self.columns["pred"]]), vmax=1000 ) @@ -306,6 +318,7 @@ def get_closest_observed(self, p): ) dists = np.sum((d - p) ** 2, axis=1) i = np.argmin(dists) + #if dists[i] <= 1e-4: if dists[i] <= 0.02: return self.train_df.iloc[i][self.columns["id"]] else: @@ -356,9 +369,7 @@ def update_active(self, _id): class SpaceTimeVisualise(object): - def __init__( - self, train_df, test_df, sat_df=None, geopandas_flag=True, test_start=None - ): + def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_start=None): columns = { "id": "id", "epoch": "epoch", @@ -378,7 +389,7 @@ def __init__( self.grid_plot_flag = not (self.test_df is None) - s self.min_time = np.min(self.train_df[columns["epoch"]]) + self.min_time = np.min(self.train_df[columns["epoch"]]) self.max_time = np.max(self.train_df[columns["epoch"]]) if test_start: @@ -408,7 +419,7 @@ def update_epoch(self, epoch): def show(self): self.fig = plt.figure(figsize=(12, 6)) - self.gs = plt.GridSpec(12, 4, wspace=0.25, hspace=0.25) + self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) self.grid_plot_1_ax = self.fig.add_subplot( self.gs[0:7, 0:2] ) # first row, first col @@ -482,6 +493,7 @@ def show(self): ) self.time_series_plot.setup() + if self.grid_plot_flag: self.val_grid_plot.plot(self.start_epoch) self.var_grid_plot.plot(self.start_epoch) @@ -494,9 +506,8 @@ def show(self): self.val_scatter_plot.plot_active(self.start_id) if self.sat_df is not None: - self.time_series_plot.ax.scatter( - self.sat_df["epoch"], self.sat_df[self.columns["observed"]], alpha=0.4 - ) + self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) + plt.show() From 5e32de2cbf648923f12c29da4b0aeb1521c4bd41 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Thu, 14 Sep 2023 10:09:37 +0100 Subject: [PATCH 03/21] fix removes --- stdata/vis/spacetime.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index bf522d1..6148c6c 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -278,11 +278,16 @@ def update_cur_epoch(self, epoch): self.plot_cur_epoch(epoch) def update(self, _id): - self.var_plot.remove() - self.observed_scatter.remove() - self.ax.lines.remove(self.pred_plot[0]) - self.min_line.remove() - self.max_line.remove() + try: + self.var_plot.remove() + self.observed_scatter.remove() + self.pred_plot[0].remove() + self.min_line.remove() + self.max_line.remove() + except ValueError as e: + # already been removed so need to remove again + pass + self.plot(_id) From 4dc5fc20a4edf2556e4a4216e9ff1e04b149d558 Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 12:40:51 +0100 Subject: [PATCH 04/21] spacetime matplot updates --- stdata/vis/spacetime.py | 113 +++++++++++++++------------------------- 1 file changed, 43 insertions(+), 70 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 6148c6c..00c5630 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -1,11 +1,14 @@ import numpy as np import matplotlib import matplotlib.pyplot as plt +import matplotlib.widgets as widgets from matplotlib.widgets import Slider from mpl_toolkits.axes_grid1 import make_axes_locatable from datetime import datetime from matplotlib.collections import PatchCollection from matplotlib.patches import Polygon +import matplotlib.colors +import matplotlib.patches as mpatches import geopandas import shapely @@ -13,9 +16,8 @@ def plot_polygon_collection( ax, geoms, - norm, + cmap="viridis", values=None, - colormap="Set1", facecolor=None, edgecolor=None, alpha=1.0, @@ -26,9 +28,7 @@ def plot_polygon_collection( patches = [] for poly in geoms: - #a = np.asarray(poly.geoms[0].exterior) - #a = np.asarray(poly.exterior) - a = np.asarray(poly.exterior.xy).T + a = np.asarray(poly.exterior.xy) if poly.has_z: poly = shapely.geometry.Polygon(zip(*poly.geoms[0].exterior.xy)) @@ -41,13 +41,12 @@ def plot_polygon_collection( linewidth=linewidth, edgecolor=edgecolor, alpha=alpha, - norm=norm, + cmap=cmap, **kwargs ) if values is not None: patches.set_array(values) - patches.set_cmap(colormap) ax.add_collection(patches, autolim=True) ax.autoscale_view() @@ -127,7 +126,6 @@ def setup(self): dir_str = "left" if self.right_flag: dir_str = "right" - self.color_bar_ax = self.divider.append_axes(dir_str, size="5%", pad=0.05) def update(self, epoch): @@ -167,9 +165,16 @@ def plot(self, epoch): return df = df.sort_values(self.columns["id"]) - geo_series = geopandas.GeoSeries(df["geom"]) - self.grid_plot = plot_polygon_collection(self.ax, geo_series, self.norm) - self.grid_plot.set_array(df[self.columns[self.col]]) + geoms = df["geom"] + self.grid_plot = plot_polygon_collection( + self.ax, + geoms, + cmap=self.cmap, + norm=self.norm, + shading="auto", + edgecolor="white", + linewidths=0.2, + ) else: s, z_train = self.get_data(epoch) if z_train is None: @@ -183,14 +188,14 @@ def plot(self, epoch): self.grid_plot = self.ax.imshow( z_train, - origin="lower", + origin="lowerleft", cmap=self.cmap, norm=self.norm, aspect="auto", extent=[min_x, max_x, min_y, max_y], ) self.fig.colorbar( - self.grid_plot, cax=self.color_bar_ax, orientation="vertical" + self.grid_plot, cax=self.color_bar_ax, orientation="horizontal" ) return self.grid_plot @@ -203,32 +208,25 @@ def __init__(self, fig, ax, unique_vals, callback): self.unique_vals = unique_vals self.callback = callback - def set_text_format(self): - datetime.fromtimestamp(1472860800).strftime("%Y-%m-%d %H") - self.slider.valtext.set_text( - datetime.fromtimestamp(self.slider.val).strftime("%Y-%m-%d %H") - ) - - def setup(self, start_val): - self.slider = Slider( + self.slider = widgets.Slider( self.ax, "Date", np.min(self.unique_vals), np.max(self.unique_vals), - valinit=start_val, + valinit=unique_vals[0], ) - self.set_text_format() self.slider.on_changed(self.update) def update(self, i): cur_epoch_i = np.abs(self.unique_vals - i).argmin() cur_epoch = self.unique_vals[cur_epoch_i] - self.set_text_format() self.callback(cur_epoch) class ST_TimeSeriesPlot(object): - def __init__(self, columns, fig, ax, train_df, test_df, test_start, grid_plot_flag): + def __init__( + self, columns, fig, ax, train_df, test_df, test_start, grid_plot_flag + ): self.columns = columns self.fig = fig self.ax = ax @@ -240,6 +238,8 @@ def __init__(self, columns, fig, ax, train_df, test_df, test_start, grid_plot_fl self.max_test_epoch = np.max(self.train_df[columns["epoch"]]) self.test_start_epoch = test_start or self.min_test_epoch + self.slider = ST_SliderPlot(fig, ax, self.train_df["epoch"], self.update_cur_epoch) + def setup(self): pass @@ -270,26 +270,12 @@ def plot(self, _id): self.max_line = self.ax.axvline(self.max_test_epoch) self.test_start_line = self.ax.axvline(self.test_start_epoch) - def plot_cur_epoch(self, epoch): - self.cur_epoch_line = self.ax.axvline(epoch, ymin=0.25, ymax=1.0) - - def update_cur_epoch(self, epoch): - self.cur_epoch_line.remove() - self.plot_cur_epoch(epoch) + self.slider.update(self.train_df[self.columns["epoch"]][0]) def update(self, _id): - try: - self.var_plot.remove() - self.observed_scatter.remove() - self.pred_plot[0].remove() - self.min_line.remove() - self.max_line.remove() - except ValueError as e: - # already been removed so need to remove again - pass - + cur_epoch = self.slider.val self.plot(_id) - + self.update_cur_epoch(cur_epoch) class ST_ScatterPlot(object): def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_df): @@ -304,15 +290,14 @@ def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_ if grid_plot_flag: self.norm = grid_plot.norm else: - print("min: ", np.min(self.train_df[self.columns["pred"]])) - print("max: ", np.max(self.train_df[self.columns["pred"]])) self.norm = matplotlib.colors.Normalize( - vmin=np.min(self.train_df[self.columns["pred"]]), vmax=1000 + vmin=np.min(self.train_df[self.columns["pred"]]), vmax=np.max(self.train_df[self.columns["pred"]]) ) self.callback = callback self.cur_epoch = None + self.cur_id = None def setup(self): self.fig.canvas.mpl_connect("button_release_event", self.on_plot_hover) @@ -323,7 +308,7 @@ def get_closest_observed(self, p): ) dists = np.sum((d - p) ** 2, axis=1) i = np.argmin(dists) - #if dists[i] <= 1e-4: + # if dists[i] <= 1e-4: if dists[i] <= 0.02: return self.train_df.iloc[i][self.columns["id"]] else: @@ -407,7 +392,6 @@ def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_sta self.start_epoch = self.unique_epochs[-1] self.start_id = self.unique_ids[0] - def update_timeseries(self, _id): self.time_series_plot.update(_id) self.val_scatter_plot.update_active(_id) @@ -420,25 +404,18 @@ def update_epoch(self, epoch): self.time_series_plot.update_cur_epoch(epoch) self.val_scatter_plot.update(epoch) - + def show(self): - self.fig = plt.figure(figsize=(12, 6)) - - self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) - self.grid_plot_1_ax = self.fig.add_subplot( - self.gs[0:7, 0:2] - ) # first row, first col - self.grid_plot_2_ax = self.fig.add_subplot( - self.gs[0:7, 2:4] - ) # first row, second col - self.epoch_slider_ax = self.fig.add_subplot( - self.gs[7, 1:3] - ) # first row, second col - self.time_series_ax = self.fig.add_subplot(self.gs[8:11, :]) # full second row - self.scale_slider_ax = self.fig.add_subplot( - self.gs[11, 1:3] - ) # first row, second col + self.fig, self.axs = plt.subplots( + 3, 2, figsize=(12, 6), gridspec_kw={"width_ratios": [1, 0.2], "height_ratios": [1, 1]} + ) + self.gs = self.fig.add_gridspec(12, 4, wspace=0.25, hspace=0.25) + self.grid_plot_1_ax = self.axs[0, 0] + self.grid_plot_2_ax = self.axs[0, 1] + self.epoch_slider_ax = self.axs[1, 0] + self.time_series_ax = self.axs[2, 0] + self.scale_slider_ax = self.axs[2, 1] if self.grid_plot_flag: self.val_grid_plot = ST_GridPlot( self.columns, @@ -498,7 +475,6 @@ def show(self): ) self.time_series_plot.setup() - if self.grid_plot_flag: self.val_grid_plot.plot(self.start_epoch) self.var_grid_plot.plot(self.start_epoch) @@ -510,9 +486,6 @@ def show(self): self.val_scatter_plot.plot_active(self.start_id) - if self.sat_df is not None: - self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) - - - plt.show() + self.slider_plot.on_changed(self.update_epoch) + plt.show() \ No newline at end of file From 531bbc398b4ee1e221a386e0d27a4d1e7b539c71 Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 13:09:44 +0100 Subject: [PATCH 05/21] try to fix --- stdata/vis/spacetime.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 00c5630..82908a0 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -392,9 +392,10 @@ def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_sta self.start_epoch = self.unique_epochs[-1] self.start_id = self.unique_ids[0] - def update_timeseries(self, _id): - self.time_series_plot.update(_id) - self.val_scatter_plot.update_active(_id) + + def update_timeseries(self): + self.time_series_plot.update() + self.val_scatter_plot.update_active() self.fig.canvas.draw_idle() def update_epoch(self, epoch): @@ -405,17 +406,24 @@ def update_epoch(self, epoch): self.time_series_plot.update_cur_epoch(epoch) self.val_scatter_plot.update(epoch) - def show(self): - self.fig, self.axs = plt.subplots( - 3, 2, figsize=(12, 6), gridspec_kw={"width_ratios": [1, 0.2], "height_ratios": [1, 1]} - ) +def show(self): + self.fig = plt.figure(figsize=(12, 6)) + + self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) + self.grid_plot_1_ax = self.fig.add_subplot( + self.gs[0:7, 0:2] + ) # first row, first col + self.grid_plot_2_ax = self.fig.add_subplot( + self.gs[0:7, 2:4] + ) # first row, second col + self.epoch_slider_ax = self.fig.add_subplot( + self.gs[7, 1:3] + ) # first row, second col + self.time_series_ax = self.fig.add_subplot(self.gs[8:11, :]) # full second row + self.scale_slider_ax = self.fig.add_subplot( + self.gs[11, 1:3] + ) # first row, second col - self.gs = self.fig.add_gridspec(12, 4, wspace=0.25, hspace=0.25) - self.grid_plot_1_ax = self.axs[0, 0] - self.grid_plot_2_ax = self.axs[0, 1] - self.epoch_slider_ax = self.axs[1, 0] - self.time_series_ax = self.axs[2, 0] - self.scale_slider_ax = self.axs[2, 1] if self.grid_plot_flag: self.val_grid_plot = ST_GridPlot( self.columns, From 73c046de0032baaa4e93f5ca2a54a67b10151cce Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 13:13:53 +0100 Subject: [PATCH 06/21] indentation error --- stdata/vis/spacetime.py | 146 ++++++++++++++++++++-------------------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 82908a0..ee66010 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -406,94 +406,94 @@ def update_epoch(self, epoch): self.time_series_plot.update_cur_epoch(epoch) self.val_scatter_plot.update(epoch) -def show(self): - self.fig = plt.figure(figsize=(12, 6)) - - self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) - self.grid_plot_1_ax = self.fig.add_subplot( - self.gs[0:7, 0:2] - ) # first row, first col - self.grid_plot_2_ax = self.fig.add_subplot( - self.gs[0:7, 2:4] - ) # first row, second col - self.epoch_slider_ax = self.fig.add_subplot( - self.gs[7, 1:3] - ) # first row, second col - self.time_series_ax = self.fig.add_subplot(self.gs[8:11, :]) # full second row - self.scale_slider_ax = self.fig.add_subplot( - self.gs[11, 1:3] - ) # first row, second col + def show(self): + self.fig = plt.figure(figsize=(12, 6)) + + self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) + self.grid_plot_1_ax = self.fig.add_subplot( + self.gs[0:7, 0:2] + ) # first row, first col + self.grid_plot_2_ax = self.fig.add_subplot( + self.gs[0:7, 2:4] + ) # first row, second col + self.epoch_slider_ax = self.fig.add_subplot( + self.gs[7, 1:3] + ) # first row, second col + self.time_series_ax = self.fig.add_subplot(self.gs[8:11, :]) # full second row + self.scale_slider_ax = self.fig.add_subplot( + self.gs[11, 1:3] + ) # first row, second col + + if self.grid_plot_flag: + self.val_grid_plot = ST_GridPlot( + self.columns, + "pred", + self.fig, + self.grid_plot_1_ax, + self.train_df, + self.test_df, + cax_on_right=False, + norm_on_training=True, + label="NO2", + geopandas_flag=self.geopandas_flag, + ) + self.val_grid_plot.setup() + + self.var_grid_plot = ST_GridPlot( + self.columns, + "var", + self.fig, + self.grid_plot_2_ax, + self.train_df, + self.test_df, + cax_on_right=False, + norm_on_training=True, + label="NO2", + geopandas_flag=self.geopandas_flag, + ) + self.var_grid_plot.setup() + else: + self.val_grid_plot = None + self.var_grid_plot = None - if self.grid_plot_flag: - self.val_grid_plot = ST_GridPlot( + self.val_scatter_plot = ST_ScatterPlot( self.columns, - "pred", self.fig, self.grid_plot_1_ax, + self.val_grid_plot, + self.grid_plot_flag, + self.update_timeseries, self.train_df, - self.test_df, - cax_on_right=False, - norm_on_training=True, - label="NO2", - geopandas_flag=self.geopandas_flag, ) - self.val_grid_plot.setup() + self.val_scatter_plot.setup() - self.var_grid_plot = ST_GridPlot( + self.slider_plot = ST_SliderPlot( + self.fig, self.epoch_slider_ax, self.unique_epochs, self.update_epoch + ) + self.slider_plot.setup(self.start_epoch) + + self.time_series_plot = ST_TimeSeriesPlot( self.columns, - "var", self.fig, - self.grid_plot_2_ax, + self.time_series_ax, self.train_df, self.test_df, - cax_on_right=False, - norm_on_training=True, - label="NO2", - geopandas_flag=self.geopandas_flag, + self.test_start, + self.grid_plot_flag, ) - self.var_grid_plot.setup() - else: - self.val_grid_plot = None - self.var_grid_plot = None - - self.val_scatter_plot = ST_ScatterPlot( - self.columns, - self.fig, - self.grid_plot_1_ax, - self.val_grid_plot, - self.grid_plot_flag, - self.update_timeseries, - self.train_df, - ) - self.val_scatter_plot.setup() + self.time_series_plot.setup() - self.slider_plot = ST_SliderPlot( - self.fig, self.epoch_slider_ax, self.unique_epochs, self.update_epoch - ) - self.slider_plot.setup(self.start_epoch) - - self.time_series_plot = ST_TimeSeriesPlot( - self.columns, - self.fig, - self.time_series_ax, - self.train_df, - self.test_df, - self.test_start, - self.grid_plot_flag, - ) - self.time_series_plot.setup() - - if self.grid_plot_flag: - self.val_grid_plot.plot(self.start_epoch) - self.var_grid_plot.plot(self.start_epoch) + if self.grid_plot_flag: + self.val_grid_plot.plot(self.start_epoch) + self.var_grid_plot.plot(self.start_epoch) - self.val_scatter_plot.plot(self.start_epoch) + self.val_scatter_plot.plot(self.start_epoch) - self.time_series_plot.plot_cur_epoch(self.start_epoch) - self.time_series_plot.plot(self.start_id) + self.time_series_plot.plot_cur_epoch(self.start_epoch) + self.time_series_plot.plot(self.start_id) - self.val_scatter_plot.plot_active(self.start_id) + self.val_scatter_plot.plot_active(self.start_id) - self.slider_plot.on_changed(self.update_epoch) + self.slider_plot.on_changed(self.update_epoch) - plt.show() \ No newline at end of file + plt.show() \ No newline at end of file From 33c937e3d66ac225d8df5c74467966f18ca65275 Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 13:27:43 +0100 Subject: [PATCH 07/21] fix ST_SliderPlot --- stdata/vis/spacetime.py | 49 +++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index ee66010..532c06a 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -199,28 +199,35 @@ def plot(self, epoch): ) return self.grid_plot + + class ST_SliderPlot(object): + def __init__(self, fig, ax, unique_vals, callback): + self.fig = fig + self.ax = ax + self.unique_vals = unique_vals + self.callback = callback + + def set_text_format(self): + self.slider.valtext.set_text( + datetime.fromtimestamp(self.slider.val).strftime("%Y-%m-%d %H") + ) - -class ST_SliderPlot(object): - def __init__(self, fig, ax, unique_vals, callback): - self.fig = fig - self.ax = ax - self.unique_vals = unique_vals - self.callback = callback - - self.slider = widgets.Slider( - self.ax, - "Date", - np.min(self.unique_vals), - np.max(self.unique_vals), - valinit=unique_vals[0], - ) - self.slider.on_changed(self.update) - - def update(self, i): - cur_epoch_i = np.abs(self.unique_vals - i).argmin() - cur_epoch = self.unique_vals[cur_epoch_i] - self.callback(cur_epoch) + def setup(self, start_val): + self.slider = widgets.Slider( + self.ax, + "Date", + np.min(self.unique_vals), + np.max(self.unique_vals), + valinit=start_val, + ) + self.set_text_format() + self.slider.on_changed(self.update) + + def update(self, i): + cur_epoch_i = np.abs(self.unique_vals - i).argmin() + cur_epoch = self.unique_vals[cur_epoch_i] + self.set_text_format() + self.callback(cur_epoch) class ST_TimeSeriesPlot(object): From efcd6bf281476fdf7fa970a0430797c8f5c579ca Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 13:31:22 +0100 Subject: [PATCH 08/21] indentation error --- stdata/vis/spacetime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 532c06a..a57fb80 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -200,7 +200,7 @@ def plot(self, epoch): return self.grid_plot - class ST_SliderPlot(object): +class ST_SliderPlot(object): def __init__(self, fig, ax, unique_vals, callback): self.fig = fig self.ax = ax From 51d815f3d2c6f0686a0744093743f5f9efff54ff Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 14:33:41 +0100 Subject: [PATCH 09/21] try --- stdata/vis/spacetime.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index a57fb80..bc20d44 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -268,14 +268,14 @@ def get_time_series(self, _id, data): def plot(self, _id): epochs, var, pred, observed = self.get_time_series(_id, self.train_df) - self.var_plot = self.ax.fill_between(epochs, pred - 1.96*np.sqrt(var), pred + 1.96*np.sqrt(var)) - self.observed_scatter = self.ax.scatter(epochs, observed) - self.pred_plot = self.ax.plot(epochs, pred) + self.var_plot = self.ax.fill_between(epochs, pred - 1.96*np.sqrt(var), pred + 1.96*np.sqrt(var), alpha=0.3) + self.observed_scatter = self.ax.scatter(epochs, observed, alpha=0.5) + self.pred_plot = self.ax.plot(epochs, pred, linewidth=2) self.ax.set_xlim([self.min_test_epoch, self.max_test_epoch]) - self.min_line = self.ax.axvline(self.min_test_epoch) - self.max_line = self.ax.axvline(self.max_test_epoch) - self.test_start_line = self.ax.axvline(self.test_start_epoch) + self.min_line = self.ax.axvline(self.min_test_epoch, color="grey", linestyle="--") + self.max_line = self.ax.axvline(self.max_test_epoch, color="grey", linestyle="--") + self.test_start_line = self.ax.axvline(self.test_start_epoch, color="grey", linestyle="--") self.slider.update(self.train_df[self.columns["epoch"]][0]) @@ -284,6 +284,10 @@ def update(self, _id): self.plot(_id) self.update_cur_epoch(cur_epoch) + def update_cur_epoch(self, cur_epoch): + self.slider.val = cur_epoch + self.ax.set_xlabel(f"Epoch [{cur_epoch}]") + class ST_ScatterPlot(object): def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_df): self.columns = columns From bae1045c788dfc85abbfaaf140daae71a518df59 Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 14:59:39 +0100 Subject: [PATCH 10/21] add .plot_cur_epoch back --- stdata/vis/spacetime.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index bc20d44..1e832b7 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -279,14 +279,18 @@ def plot(self, _id): self.slider.update(self.train_df[self.columns["epoch"]][0]) + def plot_cur_epoch(self, epoch): + self.cur_epoch_line = self.ax.axvline(epoch, ymin=0.25, ymax=1.0) + def update(self, _id): cur_epoch = self.slider.val self.plot(_id) self.update_cur_epoch(cur_epoch) - def update_cur_epoch(self, cur_epoch): + def update_cur_epoch(self, cur_epoch, _id): self.slider.val = cur_epoch self.ax.set_xlabel(f"Epoch [{cur_epoch}]") + self.plot(_id) class ST_ScatterPlot(object): def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_df): @@ -302,6 +306,8 @@ def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_ self.norm = grid_plot.norm else: self.norm = matplotlib.colors.Normalize( + print("min: ", np.min(self.train_df[self.columns["pred"]])) + print("max: ", np.max(self.train_df[self.columns["pred"]])) vmin=np.min(self.train_df[self.columns["pred"]]), vmax=np.max(self.train_df[self.columns["pred"]]) ) From d87210da9cc3dda198695d2dbf2c3c3abd2d5b43 Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 15:05:47 +0100 Subject: [PATCH 11/21] add comma --- stdata/vis/spacetime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 1e832b7..83f3196 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -306,8 +306,8 @@ def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_ self.norm = grid_plot.norm else: self.norm = matplotlib.colors.Normalize( - print("min: ", np.min(self.train_df[self.columns["pred"]])) - print("max: ", np.max(self.train_df[self.columns["pred"]])) + print("min: ", np.min(self.train_df[self.columns["pred"]])), + print("max: ", np.max(self.train_df[self.columns["pred"]])), vmin=np.min(self.train_df[self.columns["pred"]]), vmax=np.max(self.train_df[self.columns["pred"]]) ) From 28ff70dba6583d673ce4fa2075408674a97971ab Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 15:30:48 +0100 Subject: [PATCH 12/21] vmin fix --- stdata/vis/spacetime.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 83f3196..961971f 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -305,11 +305,13 @@ def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_ if grid_plot_flag: self.norm = grid_plot.norm else: - self.norm = matplotlib.colors.Normalize( - print("min: ", np.min(self.train_df[self.columns["pred"]])), - print("max: ", np.max(self.train_df[self.columns["pred"]])), - vmin=np.min(self.train_df[self.columns["pred"]]), vmax=np.max(self.train_df[self.columns["pred"]]) - ) + self.norm = matplotlib.colors.Normalize() + self.norm.vmin = np.min(self.train_df[self.columns["pred"]]) + self.norm.vmax = 1000 + + print("min: ", self.norm.vmin) + print("max: ", self.norm.vmax) + self.callback = callback From ec7c4a1b3dd78cc1b8879071cfc95e52dc6fa089 Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 17:10:04 +0100 Subject: [PATCH 13/21] set_text_format --- stdata/vis/spacetime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 961971f..9d5c339 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -208,8 +208,8 @@ def __init__(self, fig, ax, unique_vals, callback): self.callback = callback def set_text_format(self): - self.slider.valtext.set_text( - datetime.fromtimestamp(self.slider.val).strftime("%Y-%m-%d %H") + self.valtext.set_text( + datetime.fromtimestamp(self.val).strftime("%Y-%m-%d %H") ) def setup(self, start_val): From 35d6b99c4a119810f71da9453e54eb1a867ce942 Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 17:27:41 +0100 Subject: [PATCH 14/21] revert --- stdata/vis/spacetime.py | 340 ++++++++++++++++------------------------ 1 file changed, 139 insertions(+), 201 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 9d5c339..bfe5aa0 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -1,14 +1,11 @@ import numpy as np import matplotlib import matplotlib.pyplot as plt -import matplotlib.widgets as widgets from matplotlib.widgets import Slider from mpl_toolkits.axes_grid1 import make_axes_locatable from datetime import datetime from matplotlib.collections import PatchCollection from matplotlib.patches import Polygon -import matplotlib.colors -import matplotlib.patches as mpatches import geopandas import shapely @@ -16,19 +13,22 @@ def plot_polygon_collection( ax, geoms, - cmap="viridis", + norm, values=None, + colormap="Set1", facecolor=None, edgecolor=None, alpha=1.0, linewidth=1.0, - **kwargs + **kwargs, ): """ Plot a collection of Polygon geometries """ patches = [] for poly in geoms: - a = np.asarray(poly.exterior.xy) + #a = np.asarray(poly.geoms[0].exterior) + #a = np.asarray(poly.exterior) + a = np.asarray(poly.exterior.xy).T if poly.has_z: poly = shapely.geometry.Polygon(zip(*poly.geoms[0].exterior.xy)) @@ -41,18 +41,16 @@ def plot_polygon_collection( linewidth=linewidth, edgecolor=edgecolor, alpha=alpha, - cmap=cmap, + norm=norm, **kwargs ) if values is not None: patches.set_array(values) - + patches.set_cmap(colormap) ax.add_collection(patches, autolim=True) ax.autoscale_view() return patches - - class ST_GridPlot(object): def __init__( self, @@ -70,30 +68,23 @@ def __init__( self.columns = columns self.col = col self.geopandas_flag = geopandas_flag - self.fig = fig self.ax = ax - self.train_df = train_df self.test_df = test_df - self.norm_on_training = norm_on_training self.right_flag = cax_on_right self.label = label self.cmap = None - def get_spatial_slice(self, epoch): s = self.test_df[self.test_df[self.columns["epoch"]] == epoch] - if len(s) == 0: return None, None, None - return ( s[self.columns["x"]].astype(np.float32), s[self.columns["y"]].astype(np.float32), s[self.columns[self.col]].astype(np.float32), ) - def get_data(self, epoch): x_train, y_train, z_train = self.get_spatial_slice(epoch) if x_train is None: @@ -103,14 +94,12 @@ def get_data(self, epoch): z_train = np.array(z_train) s = np.c_[x_train, y_train] - n = int(np.sqrt(z_train.shape[0])) grid_index = np.lexsort((s[:, 0], s[:, 1])) s = s[grid_index, :] z_train = z_train[grid_index] z_train = (z_train).reshape(n, n) return s, z_train - def setup(self): df = self.test_df if self.norm_on_training: @@ -120,7 +109,6 @@ def setup(self): vmin=np.min(df[self.columns[self.col]]), vmax=np.max(df[self.columns[self.col]]), ) - # setup color bar self.divider = make_axes_locatable(self.ax) dir_str = "left" @@ -135,13 +123,11 @@ def update(self, epoch): if self.grid_plot is None: self.plot(epoch) return - df = self.test_df[self.test_df[self.columns["epoch"]] == epoch] df = df.sort_values(self.columns["id"]) self.grid_plot.set_array(df[self.columns[self.col]]) else: s, z_train = self.get_data(epoch) - if z_train is None: if hasattr(self, "grid_plot"): self.grid_plot.set_data([[]]) @@ -154,99 +140,75 @@ def update(self, epoch): return self.grid_plot else: return None - def plot(self, epoch): if self.geopandas_flag: df = self.test_df[self.test_df[self.columns["epoch"]] == epoch] - # If grid_plot is init with zero patches we cannot plot later if df.shape[0] == 0: self.grid_plot = None return - df = df.sort_values(self.columns["id"]) - geoms = df["geom"] - self.grid_plot = plot_polygon_collection( - self.ax, - geoms, - cmap=self.cmap, - norm=self.norm, - shading="auto", - edgecolor="white", - linewidths=0.2, - ) + geo_series = geopandas.GeoSeries(df["geom"]) + self.grid_plot = plot_polygon_collection(self.ax, geo_series, self.norm) + self.grid_plot.set_array(df[self.columns[self.col]]) else: s, z_train = self.get_data(epoch) if z_train is None: return - # get extents min_x = s[0, 0] min_y = s[0, 1] max_x = s[s.shape[0] - 1, 0] max_y = s[s.shape[0] - 1, 1] - self.grid_plot = self.ax.imshow( z_train, - origin="lowerleft", + origin="lower", cmap=self.cmap, norm=self.norm, aspect="auto", extent=[min_x, max_x, min_y, max_y], ) self.fig.colorbar( - self.grid_plot, cax=self.color_bar_ax, orientation="horizontal" + self.grid_plot, cax=self.color_bar_ax, orientation="vertical" ) - return self.grid_plot - class ST_SliderPlot(object): - def __init__(self, fig, ax, unique_vals, callback): - self.fig = fig - self.ax = ax - self.unique_vals = unique_vals - self.callback = callback - - def set_text_format(self): - self.valtext.set_text( - datetime.fromtimestamp(self.val).strftime("%Y-%m-%d %H") - ) - - def setup(self, start_val): - self.slider = widgets.Slider( - self.ax, - "Date", - np.min(self.unique_vals), - np.max(self.unique_vals), - valinit=start_val, - ) - self.set_text_format() - self.slider.on_changed(self.update) - - def update(self, i): - cur_epoch_i = np.abs(self.unique_vals - i).argmin() - cur_epoch = self.unique_vals[cur_epoch_i] - self.set_text_format() - self.callback(cur_epoch) - - + def __init__(self, fig, ax, unique_vals, callback): + self.fig = fig + self.ax = ax + self.unique_vals = unique_vals + self.callback = callback + def set_text_format(self): + datetime.fromtimestamp(1472860800).strftime("%Y-%m-%d %H") + self.slider.valtext.set_text( + datetime.fromtimestamp(self.slider.val).strftime("%Y-%m-%d %H") + ) + def setup(self, start_val): + self.slider = Slider( + self.ax, + "Date", + np.min(self.unique_vals), + np.max(self.unique_vals), + valinit=start_val, + ) + self.set_text_format() + self.slider.on_changed(self.update) + def update(self, i): + cur_epoch_i = np.abs(self.unique_vals - i).argmin() + cur_epoch = self.unique_vals[cur_epoch_i] + self.set_text_format() + self.callback(cur_epoch) class ST_TimeSeriesPlot(object): - def __init__( - self, columns, fig, ax, train_df, test_df, test_start, grid_plot_flag - ): + def __init__(self, columns, fig, ax, train_df, test_df, test_start, grid_plot_flag): self.columns = columns self.fig = fig self.ax = ax - self.train_df = train_df self.test_df = test_df - self.min_test_epoch = np.min(self.train_df[columns["epoch"]]) self.max_test_epoch = np.max(self.train_df[columns["epoch"]]) self.test_start_epoch = test_start or self.min_test_epoch - self.slider = ST_SliderPlot(fig, ax, self.train_df["epoch"], self.update_cur_epoch) - def setup(self): pass @@ -262,115 +224,100 @@ def get_time_series(self, _id, data): var = d[self.columns["var"]].astype(np.float32) pred = d[self.columns["pred"]].astype(np.float32) observed = d[self.columns["observed"]].astype(np.float32) - return epochs, var, pred, observed - + def plot(self, _id): epochs, var, pred, observed = self.get_time_series(_id, self.train_df) - self.var_plot = self.ax.fill_between(epochs, pred - 1.96*np.sqrt(var), pred + 1.96*np.sqrt(var), alpha=0.3) - self.observed_scatter = self.ax.scatter(epochs, observed, alpha=0.5) - self.pred_plot = self.ax.plot(epochs, pred, linewidth=2) + self.var_plot = self.ax.fill_between(epochs, pred - 1.96*np.sqrt(var), pred + 1.96*np.sqrt(var)) + self.observed_scatter = self.ax.scatter(epochs, observed) + self.pred_plot = self.ax.plot(epochs, pred) self.ax.set_xlim([self.min_test_epoch, self.max_test_epoch]) - - self.min_line = self.ax.axvline(self.min_test_epoch, color="grey", linestyle="--") - self.max_line = self.ax.axvline(self.max_test_epoch, color="grey", linestyle="--") - self.test_start_line = self.ax.axvline(self.test_start_epoch, color="grey", linestyle="--") - - self.slider.update(self.train_df[self.columns["epoch"]][0]) + self.min_line = self.ax.axvline(self.min_test_epoch) + self.max_line = self.ax.axvline(self.max_test_epoch) + self.test_start_line = self.ax.axvline(self.test_start_epoch) def plot_cur_epoch(self, epoch): self.cur_epoch_line = self.ax.axvline(epoch, ymin=0.25, ymax=1.0) + def update_cur_epoch(self, epoch): + self.cur_epoch_line.remove() + self.plot_cur_epoch(epoch) + def update(self, _id): - cur_epoch = self.slider.val - self.plot(_id) - self.update_cur_epoch(cur_epoch) + try: + self.var_plot.remove() + self.observed_scatter.remove() + self.pred_plot[0].remove() + self.min_line.remove() + self.max_line.remove() + except ValueError as e: + # already been removed so need to remove again + pass - def update_cur_epoch(self, cur_epoch, _id): - self.slider.val = cur_epoch - self.ax.set_xlabel(f"Epoch [{cur_epoch}]") self.plot(_id) - class ST_ScatterPlot(object): def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_df): self.columns = columns - self.fig = fig self.ax = ax - self.train_df = train_df self.cmap = None - if grid_plot_flag: self.norm = grid_plot.norm else: - self.norm = matplotlib.colors.Normalize() - self.norm.vmin = np.min(self.train_df[self.columns["pred"]]) - self.norm.vmax = 1000 - - print("min: ", self.norm.vmin) - print("max: ", self.norm.vmax) - + print("min: ", np.min(self.train_df[self.columns["pred"]])) + print("max: ", np.max(self.train_df[self.columns["pred"]])) + self.norm = matplotlib.colors.Normalize( + vmin=np.min(self.train_df[self.columns["pred"]]), vmax=1000 + ) self.callback = callback - self.cur_epoch = None - self.cur_id = None - def setup(self): self.fig.canvas.mpl_connect("button_release_event", self.on_plot_hover) - def get_closest_observed(self, p): d = np.array(self.train_df[[self.columns["x"], self.columns["y"]]]).astype( np.float32 ) dists = np.sum((d - p) ** 2, axis=1) i = np.argmin(dists) - # if dists[i] <= 1e-4: + #if dists[i] <= 1e-4: if dists[i] <= 0.02: return self.train_df.iloc[i][self.columns["id"]] else: return None - def on_plot_hover(self, event): if True or event.inaxes is self.ax: p = event.xdata, event.ydata _id = self.get_closest_observed(p) if _id is not None: self.callback(_id) - def get_spatial_slice(self, epoch, data, _id=None): s = data[data[self.columns["epoch"]] == epoch] if _id: s = s[s[self.columns["id"]] == _id] - return ( s[self.columns["x"]].astype(np.float32), s[self.columns["y"]].astype(np.float32), s[self.columns["pred"]].astype(np.float32), ) - def plot(self, epoch): self.cur_epoch = epoch x, y, z = self.get_spatial_slice(epoch, self.train_df) - self.scatter = self.ax.scatter( x, y, c=z, norm=self.norm, cmap=self.cmap, edgecolors="w" ) - def plot_active(self, _id): self.cur_id = _id x, y, z = self.get_spatial_slice(self.cur_epoch, self.train_df, _id) self.active_scatter = self.ax.scatter( x, y, c=z, norm=self.norm, cmap=self.cmap, edgecolors="y" ) - def update(self, epoch): self.scatter.remove() self.plot(epoch) self.update_active(self.cur_id) - def update_active(self, _id): self.cur_id = _id self.active_scatter.remove() @@ -391,7 +338,6 @@ def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_sta } self.columns = columns self.geopandas_flag = geopandas_flag - self.train_df = train_df self.test_df = test_df self.sat_df = sat_df @@ -405,114 +351,106 @@ def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_sta self.test_start = test_start else: self.test_start = self.min_time - self.unique_epochs = np.unique(self.train_df[columns["epoch"]]) self.unique_ids = np.unique(self.train_df[columns["id"]]) - self.start_epoch = self.unique_epochs[-1] self.start_id = self.unique_ids[0] - def update_timeseries(self): - self.time_series_plot.update() - self.val_scatter_plot.update_active() + def update_timeseries(self, _id): + self.time_series_plot.update(_id) + self.val_scatter_plot.update_active(_id) self.fig.canvas.draw_idle() def update_epoch(self, epoch): if self.grid_plot_flag: self.val_grid_plot.update(epoch) self.var_grid_plot.update(epoch) - self.time_series_plot.update_cur_epoch(epoch) self.val_scatter_plot.update(epoch) - - def show(self): - self.fig = plt.figure(figsize=(12, 6)) - - self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) - self.grid_plot_1_ax = self.fig.add_subplot( - self.gs[0:7, 0:2] - ) # first row, first col - self.grid_plot_2_ax = self.fig.add_subplot( - self.gs[0:7, 2:4] - ) # first row, second col - self.epoch_slider_ax = self.fig.add_subplot( - self.gs[7, 1:3] - ) # first row, second col - self.time_series_ax = self.fig.add_subplot(self.gs[8:11, :]) # full second row - self.scale_slider_ax = self.fig.add_subplot( - self.gs[11, 1:3] - ) # first row, second col - - if self.grid_plot_flag: - self.val_grid_plot = ST_GridPlot( - self.columns, - "pred", - self.fig, - self.grid_plot_1_ax, - self.train_df, - self.test_df, - cax_on_right=False, - norm_on_training=True, - label="NO2", - geopandas_flag=self.geopandas_flag, - ) - self.val_grid_plot.setup() - - self.var_grid_plot = ST_GridPlot( - self.columns, - "var", - self.fig, - self.grid_plot_2_ax, - self.train_df, - self.test_df, - cax_on_right=False, - norm_on_training=True, - label="NO2", - geopandas_flag=self.geopandas_flag, - ) - self.var_grid_plot.setup() - else: - self.val_grid_plot = None - self.var_grid_plot = None - self.val_scatter_plot = ST_ScatterPlot( + def show(self): + self.fig = plt.figure(figsize=(12, 6)) + + self.gs = matplotlib.gridspec.GridSpec(12, 4, wspace=0.25, hspace=0.25) + self.grid_plot_1_ax = self.fig.add_subplot( + self.gs[0:7, 0:2] + ) # first row, first col + self.grid_plot_2_ax = self.fig.add_subplot( + self.gs[0:7, 2:4] + ) # first row, second col + self.epoch_slider_ax = self.fig.add_subplot( + self.gs[7, 1:3] + ) # first row, second col + self.time_series_ax = self.fig.add_subplot(self.gs[8:11, :]) # full second row + self.scale_slider_ax = self.fig.add_subplot( + self.gs[11, 1:3] + ) # first row, second col + if self.grid_plot_flag: + self.val_grid_plot = ST_GridPlot( self.columns, + "pred", self.fig, self.grid_plot_1_ax, - self.val_grid_plot, - self.grid_plot_flag, - self.update_timeseries, self.train_df, + self.test_df, + cax_on_right=False, + norm_on_training=True, + label="NO2", + geopandas_flag=self.geopandas_flag, ) - self.val_scatter_plot.setup() - - self.slider_plot = ST_SliderPlot( - self.fig, self.epoch_slider_ax, self.unique_epochs, self.update_epoch - ) - self.slider_plot.setup(self.start_epoch) - - self.time_series_plot = ST_TimeSeriesPlot( + self.val_grid_plot.setup() + self.var_grid_plot = ST_GridPlot( self.columns, + "var", self.fig, - self.time_series_ax, + self.grid_plot_2_ax, self.train_df, self.test_df, - self.test_start, - self.grid_plot_flag, + cax_on_right=False, + norm_on_training=True, + label="NO2", + geopandas_flag=self.geopandas_flag, ) - self.time_series_plot.setup() - - if self.grid_plot_flag: - self.val_grid_plot.plot(self.start_epoch) - self.var_grid_plot.plot(self.start_epoch) + self.var_grid_plot.setup() + else: + self.val_grid_plot = None + self.var_grid_plot = None + self.val_scatter_plot = ST_ScatterPlot( + self.columns, + self.fig, + self.grid_plot_1_ax, + self.val_grid_plot, + self.grid_plot_flag, + self.update_timeseries, + self.train_df, + ) + self.val_scatter_plot.setup() + self.slider_plot = ST_SliderPlot( + self.fig, self.epoch_slider_ax, self.unique_epochs, self.update_epoch + ) + self.slider_plot.setup(self.start_epoch) + self.time_series_plot = ST_TimeSeriesPlot( + self.columns, + self.fig, + self.time_series_ax, + self.train_df, + self.test_df, + self.test_start, + self.grid_plot_flag, + ) + self.time_series_plot.setup() - self.val_scatter_plot.plot(self.start_epoch) - self.time_series_plot.plot_cur_epoch(self.start_epoch) - self.time_series_plot.plot(self.start_id) + if self.grid_plot_flag: + self.val_grid_plot.plot(self.start_epoch) + self.var_grid_plot.plot(self.start_epoch) + self.val_scatter_plot.plot(self.start_epoch) + self.time_series_plot.plot_cur_epoch(self.start_epoch) + self.time_series_plot.plot(self.start_id) + self.val_scatter_plot.plot_active(self.start_id) - self.val_scatter_plot.plot_active(self.start_id) + if self.sat_df is not None: + self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) - self.slider_plot.on_changed(self.update_epoch) - plt.show() \ No newline at end of file + plt.show() \ No newline at end of file From 7c3bd49515cfd04be4e5d4f701f64ac71c211e8e Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 20:20:29 +0100 Subject: [PATCH 15/21] support polygons with Z values --- stdata/vis/spacetime.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index bfe5aa0..1279313 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -20,18 +20,15 @@ def plot_polygon_collection( edgecolor=None, alpha=1.0, linewidth=1.0, - **kwargs, + **kwargs ): - """ Plot a collection of Polygon geometries """ + """Plot a collection of Polygon geometries""" patches = [] for poly in geoms: - #a = np.asarray(poly.geoms[0].exterior) - #a = np.asarray(poly.exterior) a = np.asarray(poly.exterior.xy).T - if poly.has_z: - poly = shapely.geometry.Polygon(zip(*poly.geoms[0].exterior.xy)) + poly = shapely.geometry.Polygon(poly.exterior.xy, z=poly.z) patches.append(Polygon(a)) @@ -48,6 +45,7 @@ def plot_polygon_collection( if values is not None: patches.set_array(values) patches.set_cmap(colormap) + ax.add_collection(patches, autolim=True) ax.autoscale_view() return patches From 363a52db5908fbcb004136e5c0e3eba2082909ad Mon Sep 17 00:00:00 2001 From: Sueda Ciftci Date: Thu, 14 Sep 2023 22:44:05 +0100 Subject: [PATCH 16/21] deneme --- stdata/vis/spacetime.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 1279313..0951688 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -49,6 +49,7 @@ def plot_polygon_collection( ax.add_collection(patches, autolim=True) ax.autoscale_view() return patches + class ST_GridPlot(object): def __init__( self, @@ -58,7 +59,7 @@ def __init__( ax, train_df, test_df, - cax_on_right, + cax_on_right=True, norm_on_training=True, label="", geopandas_flag=False, @@ -74,6 +75,8 @@ def __init__( self.right_flag = cax_on_right self.label = label self.cmap = None + self.grid_plot = None + def get_spatial_slice(self, epoch): s = self.test_df[self.test_df[self.columns["epoch"]] == epoch] if len(s) == 0: @@ -83,6 +86,7 @@ def get_spatial_slice(self, epoch): s[self.columns["y"]].astype(np.float32), s[self.columns[self.col]].astype(np.float32), ) + def get_data(self, epoch): x_train, y_train, z_train = self.get_spatial_slice(epoch) if x_train is None: @@ -98,6 +102,7 @@ def get_data(self, epoch): z_train = z_train[grid_index] z_train = (z_train).reshape(n, n) return s, z_train + def setup(self): df = self.test_df if self.norm_on_training: @@ -107,6 +112,7 @@ def setup(self): vmin=np.min(df[self.columns[self.col]]), vmax=np.max(df[self.columns[self.col]]), ) + # setup color bar self.divider = make_axes_locatable(self.ax) dir_str = "left" @@ -134,10 +140,8 @@ def update(self, epoch): self.grid_plot.set_data(z_train) else: self.plot(epoch) - if hasattr(self, "grid_plot"): - return self.grid_plot - else: - return None + self.fig.canvas.draw() + def plot(self, epoch): if self.geopandas_flag: df = self.test_df[self.test_df[self.columns["epoch"]] == epoch] @@ -169,7 +173,10 @@ def plot(self, epoch): self.fig.colorbar( self.grid_plot, cax=self.color_bar_ax, orientation="vertical" ) + self.ax.set_title(f"Epoch {epoch} {self.label}") return self.grid_plot + + class ST_SliderPlot(object): def __init__(self, fig, ax, unique_vals, callback): self.fig = fig @@ -196,6 +203,9 @@ def update(self, i): cur_epoch = self.unique_vals[cur_epoch_i] self.set_text_format() self.callback(cur_epoch) + + + class ST_TimeSeriesPlot(object): def __init__(self, columns, fig, ax, train_df, test_df, test_start, grid_plot_flag): self.columns = columns @@ -362,7 +372,6 @@ def update_timeseries(self, _id): def update_epoch(self, epoch): if self.grid_plot_flag: self.val_grid_plot.update(epoch) - self.var_grid_plot.update(epoch) self.time_series_plot.update_cur_epoch(epoch) self.val_scatter_plot.update(epoch) From f6fe414711779915538d31b99b06ea5c6ae106a7 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Mon, 18 Sep 2023 11:25:27 +0100 Subject: [PATCH 17/21] add back missing var grid plot update --- stdata/vis/spacetime.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 0951688..5f60d98 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -261,6 +261,7 @@ def update(self, _id): self.max_line.remove() except ValueError as e: # already been removed so need to remove again + print(e) pass self.plot(_id) @@ -372,6 +373,7 @@ def update_timeseries(self, _id): def update_epoch(self, epoch): if self.grid_plot_flag: self.val_grid_plot.update(epoch) + self.var_grid_plot.update(epoch) self.time_series_plot.update_cur_epoch(epoch) self.val_scatter_plot.update(epoch) @@ -460,4 +462,4 @@ def show(self): self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) - plt.show() \ No newline at end of file + plt.show() From 671f32dda27567d5681b2d961209621004ef95e3 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Tue, 18 Jun 2024 10:08:19 +0100 Subject: [PATCH 18/21] plot satellite first so it doesnt overlap laqn --- stdata/vis/spacetime.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index 5f60d98..9b5160c 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -449,6 +449,8 @@ def show(self): ) self.time_series_plot.setup() + if self.sat_df is not None: + self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) if self.grid_plot_flag: self.val_grid_plot.plot(self.start_epoch) @@ -458,8 +460,5 @@ def show(self): self.time_series_plot.plot(self.start_id) self.val_scatter_plot.plot_active(self.start_id) - if self.sat_df is not None: - self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) - plt.show() From b38ee482d74e6c5c74753265452a0c28956d7088 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Tue, 24 Sep 2024 16:00:17 +0300 Subject: [PATCH 19/21] remove requirement that df is numpy in spatial_cross_validation --- stdata/model_selection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stdata/model_selection.py b/stdata/model_selection.py index 607fda7..ad99012 100644 --- a/stdata/model_selection.py +++ b/stdata/model_selection.py @@ -129,7 +129,6 @@ def spatial_k_fold_generator(df, num_folds, group_col='group'): """ A wrapper so that spatial k fold can be used with the same syntax as sklearn k fold""" class _gen(): def split(self, df_to_split): - df_to_split = np.array(df_to_split) for k in range(num_folds): train_index = (df[group_col] != k) test_index = (df[group_col] == k) From a8a719cd8c58ec59c83cfbcded24eed6ba90f907 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Thu, 26 Sep 2024 17:08:17 +0300 Subject: [PATCH 20/21] simplify equal spatial clusters --- stdata/model_selection.py | 72 +-------------------------------------- 1 file changed, 1 insertion(+), 71 deletions(-) diff --git a/stdata/model_selection.py b/stdata/model_selection.py index ad99012..00aceb6 100644 --- a/stdata/model_selection.py +++ b/stdata/model_selection.py @@ -149,77 +149,7 @@ def _equal_k_means(df, n_clusters=5, verbose=False): m = cluster.KMeans(n_clusters=n_clusters).fit(X) - dists = pairwise_distances(m.cluster_centers_, X) - - clusters = {c: [] for c in range(n_clusters)} - - assigned_ids = [] - - N = X.shape[0] - - dists_all = [dists[c].argsort() for c in range(n_clusters)] - - # each step assigns n_cluster points to assigned_ids - num_iters = int(np.ceil(N/float(n_clusters))) - - if verbose: - bar = tqdm(total=num_iters) - - for i in range(num_iters): - for c in range(n_clusters): - - # find closest point - all_closest_points = dists_all[c] - - closest_points = all_closest_points[~np.isin(all_closest_points,assigned_ids)] - closest_point = closest_points[0] - closest_point_dist = dists[c][closest_point] - - # find the closest cluster for closest point - closest_cluster = dists[:, closest_point].argsort()[0] - - if c != closest_cluster: - # find assigned point in cluster that is closest to c - closest_points_in_new_cluster = dists[c][ - clusters[closest_cluster] - ].argsort() - - if len(closest_points_in_new_cluster) == 0: - clusters[c].append(closest_point) - else: - closest_point_in_new_cluster = clusters[closest_cluster][closest_points_in_new_cluster[0]] - - if dists[c][closest_point_in_new_cluster] < closest_point_dist: - clusters[closest_cluster].remove(closest_point_in_new_cluster) - clusters[closest_cluster].append(closest_point) - clusters[c].append(closest_point_in_new_cluster) - else: - # do nothing - clusters[c].append(closest_point) - else: - clusters[c].append(closest_point) - - assigned_ids.append(closest_point) - - - if len(assigned_ids) == N: - break - - if len(assigned_ids) == N: - break - - if verbose: - bar.update(1) - - cluster_df = pd.DataFrame( - [[i, c] for c, a in clusters.items() for i in a], - columns=['__index_cluster', 'label'] - ) - - df = df.merge(cluster_df, left_on=['__index'], right_on=['__index_cluster'], how='left', suffixes=[None, '_y']) - df = df.drop(columns=['__index', '__index_cluster']) - - df['k_means_label'] = m.labels_ + df['label'] = m.labels_ return df From 0541267ed1210b8af701831280952df8916fc447 Mon Sep 17 00:00:00 2001 From: Oliver Hamelijnck Date: Thu, 26 Sep 2024 18:54:49 +0300 Subject: [PATCH 21/21] fix equal kmeans --- stdata/model_selection.py | 40 ++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/stdata/model_selection.py b/stdata/model_selection.py index 00aceb6..8862242 100644 --- a/stdata/model_selection.py +++ b/stdata/model_selection.py @@ -8,6 +8,8 @@ import sklearn from sklearn import cluster from sklearn.metrics import pairwise_distances +from scipy.spatial.distance import cdist +from scipy.optimize import linear_sum_assignment from tqdm import tqdm @@ -137,20 +139,36 @@ def split(self, df_to_split): return _gen() -def _equal_k_means(df, n_clusters=5, verbose=False): - """ - This uses a greedy algorithm, and so although each cluster will be equal, it may not be spatially aligned. - """ - df = df.copy().reset_index() - - df['__index'] = df.index - - X = np.array(df[['lat', 'lon']]) + +def _equal_k_means(df, n_clusters, verbose=False): + """ + Taken from Eyal Shulman implementation https://stackoverflow.com/questions/5452576/k-means-algorithm-variation-with-equal-cluster-size - m = cluster.KMeans(n_clusters=n_clusters).fit(X) + Edited to make it work with slightly uneven clusters + """ + + df = df.copy() + points = np.array(df[['lat', 'lon']]) + n_points = points.shape[0] - df['label'] = m.labels_ + num_to_remove = int(abs((n_points-n_clusters*np.ceil(n_points/n_clusters)))) + points_removed = points[:num_to_remove] + X = points[num_to_remove:n_points] + + cluster_size = int(np.ceil(len(X)/n_clusters)) + kmeans = cluster.KMeans(n_clusters) + kmeans.fit(X) + k_centers = kmeans.cluster_centers_ + centers = k_centers + centers = centers.reshape(-1, 1, X.shape[-1]).repeat(cluster_size, 1).reshape(-1, X.shape[-1]) + distance_matrix = cdist(X, centers) + clusters = linear_sum_assignment(distance_matrix)[1]//cluster_size + # add points removed back to there closest points + distance_matrix = cdist(points_removed, k_centers) + points_to_add_back = np.argmax(distance_matrix, axis=1) + clusters = np.hstack([points_to_add_back, clusters]) + df['label'] = clusters return df def equal_spatial_clusters(df, n_clusters=5, lat_col='lat', lon_col='lon', group_col='label', verbose=False):