diff --git a/stdata/model_selection.py b/stdata/model_selection.py index 607fda7..8862242 100644 --- a/stdata/model_selection.py +++ b/stdata/model_selection.py @@ -8,6 +8,8 @@ import sklearn from sklearn import cluster from sklearn.metrics import pairwise_distances +from scipy.spatial.distance import cdist +from scipy.optimize import linear_sum_assignment from tqdm import tqdm @@ -129,7 +131,6 @@ def spatial_k_fold_generator(df, num_folds, group_col='group'): """ A wrapper so that spatial k fold can be used with the same syntax as sklearn k fold""" class _gen(): def split(self, df_to_split): - df_to_split = np.array(df_to_split) for k in range(num_folds): train_index = (df[group_col] != k) test_index = (df[group_col] == k) @@ -138,90 +139,36 @@ def split(self, df_to_split): return _gen() -def _equal_k_means(df, n_clusters=5, verbose=False): - """ - This uses a greedy algorithm, and so although each cluster will be equal, it may not be spatially aligned. - """ - df = df.copy().reset_index() - - df['__index'] = df.index - - X = np.array(df[['lat', 'lon']]) - - m = cluster.KMeans(n_clusters=n_clusters).fit(X) - - dists = pairwise_distances(m.cluster_centers_, X) - - clusters = {c: [] for c in range(n_clusters)} - - assigned_ids = [] - - N = X.shape[0] - dists_all = [dists[c].argsort() for c in range(n_clusters)] +def _equal_k_means(df, n_clusters, verbose=False): + """ + Taken from Eyal Shulman implementation https://stackoverflow.com/questions/5452576/k-means-algorithm-variation-with-equal-cluster-size - # each step assigns n_cluster points to assigned_ids - num_iters = int(np.ceil(N/float(n_clusters))) - - if verbose: - bar = tqdm(total=num_iters) - - for i in range(num_iters): - for c in range(n_clusters): - - # find closest point - all_closest_points = dists_all[c] - - closest_points = all_closest_points[~np.isin(all_closest_points,assigned_ids)] - closest_point = closest_points[0] - closest_point_dist = dists[c][closest_point] - - # find the closest cluster for closest point - closest_cluster = dists[:, closest_point].argsort()[0] - - if c != closest_cluster: - # find assigned point in cluster that is closest to c - closest_points_in_new_cluster = dists[c][ - clusters[closest_cluster] - ].argsort() - - if len(closest_points_in_new_cluster) == 0: - clusters[c].append(closest_point) - else: - closest_point_in_new_cluster = clusters[closest_cluster][closest_points_in_new_cluster[0]] - - if dists[c][closest_point_in_new_cluster] < closest_point_dist: - clusters[closest_cluster].remove(closest_point_in_new_cluster) - clusters[closest_cluster].append(closest_point) - clusters[c].append(closest_point_in_new_cluster) - else: - # do nothing - clusters[c].append(closest_point) - else: - clusters[c].append(closest_point) - - assigned_ids.append(closest_point) - - - if len(assigned_ids) == N: - break - - if len(assigned_ids) == N: - break + Edited to make it work with slightly uneven clusters + """ - if verbose: - bar.update(1) - - cluster_df = pd.DataFrame( - [[i, c] for c, a in clusters.items() for i in a], - columns=['__index_cluster', 'label'] - ) + df = df.copy() + points = np.array(df[['lat', 'lon']]) + n_points = points.shape[0] - df = df.merge(cluster_df, left_on=['__index'], right_on=['__index_cluster'], how='left', suffixes=[None, '_y']) - df = df.drop(columns=['__index', '__index_cluster']) - - df['k_means_label'] = m.labels_ + num_to_remove = int(abs((n_points-n_clusters*np.ceil(n_points/n_clusters)))) + points_removed = points[:num_to_remove] + X = points[num_to_remove:n_points] + + cluster_size = int(np.ceil(len(X)/n_clusters)) + kmeans = cluster.KMeans(n_clusters) + kmeans.fit(X) + k_centers = kmeans.cluster_centers_ + centers = k_centers + centers = centers.reshape(-1, 1, X.shape[-1]).repeat(cluster_size, 1).reshape(-1, X.shape[-1]) + distance_matrix = cdist(X, centers) + clusters = linear_sum_assignment(distance_matrix)[1]//cluster_size + # add points removed back to there closest points + distance_matrix = cdist(points_removed, k_centers) + points_to_add_back = np.argmax(distance_matrix, axis=1) + clusters = np.hstack([points_to_add_back, clusters]) + df['label'] = clusters return df def equal_spatial_clusters(df, n_clusters=5, lat_col='lat', lon_col='lon', group_col='label', verbose=False): diff --git a/stdata/vis/spacetime.py b/stdata/vis/spacetime.py index edcff6a..9b5160c 100644 --- a/stdata/vis/spacetime.py +++ b/stdata/vis/spacetime.py @@ -22,15 +22,13 @@ def plot_polygon_collection( linewidth=1.0, **kwargs ): - """ Plot a collection of Polygon geometries """ + """Plot a collection of Polygon geometries""" patches = [] for poly in geoms: - #a = np.asarray(poly.geoms[0].exterior) - a = np.asarray(poly.exterior) - + a = np.asarray(poly.exterior.xy).T if poly.has_z: - poly = shapely.geometry.Polygon(zip(*poly.geoms[0].exterior.xy)) + poly = shapely.geometry.Polygon(poly.exterior.xy, z=poly.z) patches.append(Polygon(a)) @@ -52,7 +50,6 @@ def plot_polygon_collection( ax.autoscale_view() return patches - class ST_GridPlot(object): def __init__( self, @@ -62,7 +59,7 @@ def __init__( ax, train_df, test_df, - cax_on_right, + cax_on_right=True, norm_on_training=True, label="", geopandas_flag=False, @@ -70,24 +67,20 @@ def __init__( self.columns = columns self.col = col self.geopandas_flag = geopandas_flag - self.fig = fig self.ax = ax - self.train_df = train_df self.test_df = test_df - self.norm_on_training = norm_on_training self.right_flag = cax_on_right self.label = label self.cmap = None + self.grid_plot = None def get_spatial_slice(self, epoch): s = self.test_df[self.test_df[self.columns["epoch"]] == epoch] - if len(s) == 0: return None, None, None - return ( s[self.columns["x"]].astype(np.float32), s[self.columns["y"]].astype(np.float32), @@ -103,7 +96,6 @@ def get_data(self, epoch): z_train = np.array(z_train) s = np.c_[x_train, y_train] - n = int(np.sqrt(z_train.shape[0])) grid_index = np.lexsort((s[:, 0], s[:, 1])) s = s[grid_index, :] @@ -126,7 +118,6 @@ def setup(self): dir_str = "left" if self.right_flag: dir_str = "right" - self.color_bar_ax = self.divider.append_axes(dir_str, size="5%", pad=0.05) def update(self, epoch): @@ -136,13 +127,11 @@ def update(self, epoch): if self.grid_plot is None: self.plot(epoch) return - df = self.test_df[self.test_df[self.columns["epoch"]] == epoch] df = df.sort_values(self.columns["id"]) self.grid_plot.set_array(df[self.columns[self.col]]) else: s, z_train = self.get_data(epoch) - if z_train is None: if hasattr(self, "grid_plot"): self.grid_plot.set_data([[]]) @@ -151,20 +140,15 @@ def update(self, epoch): self.grid_plot.set_data(z_train) else: self.plot(epoch) - if hasattr(self, "grid_plot"): - return self.grid_plot - else: - return None + self.fig.canvas.draw() def plot(self, epoch): if self.geopandas_flag: df = self.test_df[self.test_df[self.columns["epoch"]] == epoch] - # If grid_plot is init with zero patches we cannot plot later if df.shape[0] == 0: self.grid_plot = None return - df = df.sort_values(self.columns["id"]) geo_series = geopandas.GeoSeries(df["geom"]) self.grid_plot = plot_polygon_collection(self.ax, geo_series, self.norm) @@ -173,13 +157,11 @@ def plot(self, epoch): s, z_train = self.get_data(epoch) if z_train is None: return - # get extents min_x = s[0, 0] min_y = s[0, 1] max_x = s[s.shape[0] - 1, 0] max_y = s[s.shape[0] - 1, 1] - self.grid_plot = self.ax.imshow( z_train, origin="lower", @@ -191,23 +173,21 @@ def plot(self, epoch): self.fig.colorbar( self.grid_plot, cax=self.color_bar_ax, orientation="vertical" ) - + self.ax.set_title(f"Epoch {epoch} {self.label}") return self.grid_plot - - + + class ST_SliderPlot(object): def __init__(self, fig, ax, unique_vals, callback): self.fig = fig self.ax = ax self.unique_vals = unique_vals self.callback = callback - def set_text_format(self): datetime.fromtimestamp(1472860800).strftime("%Y-%m-%d %H") self.slider.valtext.set_text( datetime.fromtimestamp(self.slider.val).strftime("%Y-%m-%d %H") ) - def setup(self, start_val): self.slider = Slider( self.ax, @@ -218,7 +198,6 @@ def setup(self, start_val): ) self.set_text_format() self.slider.on_changed(self.update) - def update(self, i): cur_epoch_i = np.abs(self.unique_vals - i).argmin() cur_epoch = self.unique_vals[cur_epoch_i] @@ -226,15 +205,14 @@ def update(self, i): self.callback(cur_epoch) + class ST_TimeSeriesPlot(object): def __init__(self, columns, fig, ax, train_df, test_df, test_start, grid_plot_flag): self.columns = columns self.fig = fig self.ax = ax - self.train_df = train_df self.test_df = test_df - self.min_test_epoch = np.min(self.train_df[columns["epoch"]]) self.max_test_epoch = np.max(self.train_df[columns["epoch"]]) self.test_start_epoch = test_start or self.min_test_epoch @@ -254,9 +232,8 @@ def get_time_series(self, _id, data): var = d[self.columns["var"]].astype(np.float32) pred = d[self.columns["pred"]].astype(np.float32) observed = d[self.columns["observed"]].astype(np.float32) - return epochs, var, pred, observed - + def plot(self, _id): epochs, var, pred, observed = self.get_time_series(_id, self.train_df) @@ -264,7 +241,6 @@ def plot(self, _id): self.observed_scatter = self.ax.scatter(epochs, observed) self.pred_plot = self.ax.plot(epochs, pred) self.ax.set_xlim([self.min_test_epoch, self.max_test_epoch]) - self.min_line = self.ax.axvline(self.min_test_epoch) self.max_line = self.ax.axvline(self.max_test_epoch) self.test_start_line = self.ax.axvline(self.test_start_epoch) @@ -277,24 +253,25 @@ def update_cur_epoch(self, epoch): self.plot_cur_epoch(epoch) def update(self, _id): - self.var_plot.remove() - self.observed_scatter.remove() - self.ax.lines.remove(self.pred_plot[0]) - self.min_line.remove() - self.max_line.remove() - self.plot(_id) - + try: + self.var_plot.remove() + self.observed_scatter.remove() + self.pred_plot[0].remove() + self.min_line.remove() + self.max_line.remove() + except ValueError as e: + # already been removed so need to remove again + print(e) + pass + self.plot(_id) class ST_ScatterPlot(object): def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_df): self.columns = columns - self.fig = fig self.ax = ax - self.train_df = train_df self.cmap = None - if grid_plot_flag: self.norm = grid_plot.norm else: @@ -305,12 +282,9 @@ def __init__(self, columns, fig, ax, grid_plot, grid_plot_flag, callback, train_ ) self.callback = callback - self.cur_epoch = None - def setup(self): self.fig.canvas.mpl_connect("button_release_event", self.on_plot_hover) - def get_closest_observed(self, p): d = np.array(self.train_df[[self.columns["x"], self.columns["y"]]]).astype( np.float32 @@ -322,45 +296,37 @@ def get_closest_observed(self, p): return self.train_df.iloc[i][self.columns["id"]] else: return None - def on_plot_hover(self, event): if True or event.inaxes is self.ax: p = event.xdata, event.ydata _id = self.get_closest_observed(p) if _id is not None: self.callback(_id) - def get_spatial_slice(self, epoch, data, _id=None): s = data[data[self.columns["epoch"]] == epoch] if _id: s = s[s[self.columns["id"]] == _id] - return ( s[self.columns["x"]].astype(np.float32), s[self.columns["y"]].astype(np.float32), s[self.columns["pred"]].astype(np.float32), ) - def plot(self, epoch): self.cur_epoch = epoch x, y, z = self.get_spatial_slice(epoch, self.train_df) - self.scatter = self.ax.scatter( x, y, c=z, norm=self.norm, cmap=self.cmap, edgecolors="w" ) - def plot_active(self, _id): self.cur_id = _id x, y, z = self.get_spatial_slice(self.cur_epoch, self.train_df, _id) self.active_scatter = self.ax.scatter( x, y, c=z, norm=self.norm, cmap=self.cmap, edgecolors="y" ) - def update(self, epoch): self.scatter.remove() self.plot(epoch) self.update_active(self.cur_id) - def update_active(self, _id): self.cur_id = _id self.active_scatter.remove() @@ -381,7 +347,6 @@ def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_sta } self.columns = columns self.geopandas_flag = geopandas_flag - self.train_df = train_df self.test_df = test_df self.sat_df = sat_df @@ -395,10 +360,8 @@ def __init__(self, train_df, test_df, sat_df=None, geopandas_flag=True, test_sta self.test_start = test_start else: self.test_start = self.min_time - self.unique_epochs = np.unique(self.train_df[columns["epoch"]]) self.unique_ids = np.unique(self.train_df[columns["id"]]) - self.start_epoch = self.unique_epochs[-1] self.start_id = self.unique_ids[0] @@ -411,7 +374,6 @@ def update_epoch(self, epoch): if self.grid_plot_flag: self.val_grid_plot.update(epoch) self.var_grid_plot.update(epoch) - self.time_series_plot.update_cur_epoch(epoch) self.val_scatter_plot.update(epoch) @@ -432,7 +394,6 @@ def show(self): self.scale_slider_ax = self.fig.add_subplot( self.gs[11, 1:3] ) # first row, second col - if self.grid_plot_flag: self.val_grid_plot = ST_GridPlot( self.columns, @@ -447,7 +408,6 @@ def show(self): geopandas_flag=self.geopandas_flag, ) self.val_grid_plot.setup() - self.var_grid_plot = ST_GridPlot( self.columns, "var", @@ -464,7 +424,6 @@ def show(self): else: self.val_grid_plot = None self.var_grid_plot = None - self.val_scatter_plot = ST_ScatterPlot( self.columns, self.fig, @@ -475,12 +434,10 @@ def show(self): self.train_df, ) self.val_scatter_plot.setup() - self.slider_plot = ST_SliderPlot( self.fig, self.epoch_slider_ax, self.unique_epochs, self.update_epoch ) self.slider_plot.setup(self.start_epoch) - self.time_series_plot = ST_TimeSeriesPlot( self.columns, self.fig, @@ -492,21 +449,16 @@ def show(self): ) self.time_series_plot.setup() + if self.sat_df is not None: + self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) if self.grid_plot_flag: self.val_grid_plot.plot(self.start_epoch) self.var_grid_plot.plot(self.start_epoch) - self.val_scatter_plot.plot(self.start_epoch) - self.time_series_plot.plot_cur_epoch(self.start_epoch) self.time_series_plot.plot(self.start_id) - self.val_scatter_plot.plot_active(self.start_id) - if self.sat_df is not None: - self.time_series_plot.ax.scatter(self.sat_df['epoch'], self.sat_df[self.columns['observed']], alpha=0.4) - plt.show() -