From 2a51b794aa92c99dcec0a9aa8050a6103f2d9f08 Mon Sep 17 00:00:00 2001 From: Sonhaf-prio Date: Fri, 23 Jan 2026 13:45:30 +0100 Subject: [PATCH 1/2] changed resolution of country shapefiles to 1:10 --- .../ne_10m_admin_0_countries.cpg | 3 + .../ne_10m_admin_0_countries.dbf | 3 + .../ne_10m_admin_0_countries.prj | 3 + .../ne_10m_admin_0_countries.shp | 3 + .../ne_10m_admin_0_countries.shx | 3 + views_postprocessing/unfao/mapping/mapping.py | 1369 +++++++++++------ 6 files changed, 875 insertions(+), 509 deletions(-) create mode 100644 views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.cpg create mode 100644 views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.dbf create mode 100644 views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.prj create mode 100644 views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shp create mode 100644 views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shx diff --git a/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.cpg b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.cpg new file mode 100644 index 0000000..105bf61 --- /dev/null +++ b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.cpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ad3031f5503a4404af825262ee8232cc04d4ea6683d42c5dd0a2f2a27ac9824 +size 5 diff --git a/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.dbf b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.dbf new file mode 100644 index 0000000..b36e8ca --- /dev/null +++ b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.dbf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5dbd3dd5fd7e2ef49051fc88562c03819e8ea63a382642df6eadd1243bf4b49 +size 878482 diff --git a/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.prj b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.prj new file mode 100644 index 0000000..2098c4d --- /dev/null +++ b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.prj @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a02a27b1d1982c8516d83398e85a3c8b1aef1713c13ef4d84d7bde17430c07c4 +size 145 diff --git a/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shp b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shp new file mode 100644 index 0000000..29356e6 --- /dev/null +++ b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ce119ef6342e43cff7c0c3004e0911ab7ec1988a14734372031d2012180e7bc +size 8806224 diff --git a/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shx b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shx new file mode 100644 index 0000000..1fef3c4 --- /dev/null +++ b/views_postprocessing/shapefiles/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca19ec112d054c77bc8f7ac00e3b110d5dff32cc9bcf4cd1b8b66bdd0f611d32 +size 2164 diff --git a/views_postprocessing/unfao/mapping/mapping.py b/views_postprocessing/unfao/mapping/mapping.py index c201929..8f93f6e 100644 --- a/views_postprocessing/unfao/mapping/mapping.py +++ b/views_postprocessing/unfao/mapping/mapping.py @@ -23,13 +23,33 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -warnings.filterwarnings('ignore') - -NATURAL_EARTH_COUNTRY_PATH = Path(__file__).parent.parent.parent / "shapefiles" / "ne_110m_admin_0_countries" / "ne_110m_admin_0_countries.shp" - -PRIOGRID_SHAPEFILE_PATH = Path(__file__).parent.parent.parent / "shapefiles" / "priogrid_cellshp" / "priogrid_cell.shp" -ADM_1_SHAPEFILE_PATH = Path(__file__).parent.parent.parent / "shapefiles" / "GAUL_2024_L1" / "GAUL_2024_L1.shp" -ADM_2_SHAPEFILE_PATH = Path(__file__).parent.parent.parent / "shapefiles" / "GAUL_2024_L2" / "GAUL_2024_L2.shp" +warnings.filterwarnings("ignore") + +NATURAL_EARTH_COUNTRY_PATH = ( + Path(__file__).parent.parent.parent + / "shapefiles" + / "ne_10m_admin_0_countries" + / "ne_10m_admin_0_countries.shp" +) + +PRIOGRID_SHAPEFILE_PATH = ( + Path(__file__).parent.parent.parent + / "shapefiles" + / "priogrid_cellshp" + / "priogrid_cell.shp" +) +ADM_1_SHAPEFILE_PATH = ( + Path(__file__).parent.parent.parent + / "shapefiles" + / "GAUL_2024_L1" + / "GAUL_2024_L1.shp" +) +ADM_2_SHAPEFILE_PATH = ( + Path(__file__).parent.parent.parent + / "shapefiles" + / "GAUL_2024_L2" + / "GAUL_2024_L2.shp" +) # Configure disk cache CACHE_DIR = Path.home() / ".priogrid_mapper_cache" @@ -124,7 +144,7 @@ class PriogridCountryMapper: def __init__(self, use_disk_cache=True, cache_dir=None, cache_ttl=None): """ Initialize the PriogridCountryMapper with optional disk caching. - + Parameters: use_disk_cache (bool): Whether to use persistent disk caching cache_dir (str or Path): Custom cache directory path @@ -139,7 +159,7 @@ def __init__(self, use_disk_cache=True, cache_dir=None, cache_ttl=None): # NOTE: This is the line that was crashing in your traceback (line 134) country_path = str(NATURAL_EARTH_COUNTRY_PATH) self.countries_gdf = self._load_and_preprocess_naturalearth(country_path) - + # Load PRIO-GRID data self.priogrid_gdf = self._load_priogrid(str(PRIOGRID_SHAPEFILE_PATH)) @@ -150,7 +170,7 @@ def __init__(self, use_disk_cache=True, cache_dir=None, cache_ttl=None): # Configure caching self.use_disk_cache = use_disk_cache self.cache_ttl = cache_ttl - + if use_disk_cache: # Set up custom cache directory if provided if cache_dir: @@ -158,38 +178,54 @@ def __init__(self, use_disk_cache=True, cache_dir=None, cache_ttl=None): os.makedirs(self.cache_dir, exist_ok=True) else: self.cache_dir = CACHE_DIR - + # Initialize disk caches - self._disk_country_cache = Memory(location=self.cache_dir / "country", verbose=0) - self._disk_admin1_cache = Memory(location=self.cache_dir / "admin1", verbose=0) - self._disk_admin2_cache = Memory(location=self.cache_dir / "admin2", verbose=0) + self._disk_country_cache = Memory( + location=self.cache_dir / "country", verbose=0 + ) + self._disk_admin1_cache = Memory( + location=self.cache_dir / "admin1", verbose=0 + ) + self._disk_admin2_cache = Memory( + location=self.cache_dir / "admin2", verbose=0 + ) self._disk_gid_cache = Memory(location=self.cache_dir / "gid", verbose=0) - + logger.info(f"Using disk cache at: {self.cache_dir}") else: # Initialize in-memory caches as instance variables if cache_ttl: - self._country_cache = TTLCache(maxsize=COUNTRY_CACHE_MAXSIZE, ttl=cache_ttl) + self._country_cache = TTLCache( + maxsize=COUNTRY_CACHE_MAXSIZE, ttl=cache_ttl + ) self._gid_cache = TTLCache(maxsize=GID_CACHE_MAXSIZE, ttl=cache_ttl) - self._gids_for_country_cache = TTLCache(maxsize=GIDS_FOR_COUNTRY_CACHE_MAXSIZE, ttl=cache_ttl) - self._admin1_cache = TTLCache(maxsize=ADMIN1_CACHE_MAXSIZE, ttl=cache_ttl) - self._admin2_cache = TTLCache(maxsize=ADMIN2_CACHE_MAXSIZE, ttl=cache_ttl) + self._gids_for_country_cache = TTLCache( + maxsize=GIDS_FOR_COUNTRY_CACHE_MAXSIZE, ttl=cache_ttl + ) + self._admin1_cache = TTLCache( + maxsize=ADMIN1_CACHE_MAXSIZE, ttl=cache_ttl + ) + self._admin2_cache = TTLCache( + maxsize=ADMIN2_CACHE_MAXSIZE, ttl=cache_ttl + ) else: self._country_cache = LRUCache(maxsize=COUNTRY_CACHE_MAXSIZE) self._gid_cache = LRUCache(maxsize=GID_CACHE_MAXSIZE) - self._gids_for_country_cache = LRUCache(maxsize=GIDS_FOR_COUNTRY_CACHE_MAXSIZE) + self._gids_for_country_cache = LRUCache( + maxsize=GIDS_FOR_COUNTRY_CACHE_MAXSIZE + ) self._admin1_cache = LRUCache(maxsize=ADMIN1_CACHE_MAXSIZE) self._admin2_cache = LRUCache(maxsize=ADMIN2_CACHE_MAXSIZE) - + logger.info("Using in-memory caching") # Load admin1 and admin2 data if paths are provided self.admin1_path = str(ADM_1_SHAPEFILE_PATH) self.admin2_path = str(ADM_2_SHAPEFILE_PATH) - + self.admin1_gdf = None self.admin2_gdf = None - + if self.admin1_path: self.admin1_gdf = self._load_admin_data(self.admin1_path, "admin1") self.admin1_sindex = ( @@ -197,7 +233,7 @@ def __init__(self, use_disk_cache=True, cache_dir=None, cache_ttl=None): ) else: self.admin1_sindex = None - + if self.admin2_path: self.admin2_gdf = self._load_admin_data(self.admin2_path, "admin2") self.admin2_sindex = ( @@ -225,7 +261,7 @@ def _get_admin_cache_key(self, gid, admin_level): def clear_cache(self, cache_type="all"): """ Clear cached data. - + Parameters: cache_type (str): Type of cache to clear ("all", "country", "admin1", "admin2", "gid") """ @@ -233,15 +269,15 @@ def clear_cache(self, cache_type="all"): if cache_type == "all" or cache_type == "country": self._disk_country_cache.clear() logger.info("Cleared country cache") - + if cache_type == "all" or cache_type == "admin1": self._disk_admin1_cache.clear() logger.info("Cleared admin1 cache") - + if cache_type == "all" or cache_type == "admin2": self._disk_admin2_cache.clear() logger.info("Cleared admin2 cache") - + if cache_type == "all" or cache_type == "gid": self._disk_gid_cache.clear() logger.info("Cleared GID cache") @@ -249,15 +285,15 @@ def clear_cache(self, cache_type="all"): if cache_type == "all" or cache_type == "country": self._country_cache.clear() logger.info("Cleared country cache") - + if cache_type == "all" or cache_type == "admin1": self._admin1_cache.clear() logger.info("Cleared admin1 cache") - + if cache_type == "all" or cache_type == "admin2": self._admin2_cache.clear() logger.info("Cleared admin2 cache") - + if cache_type == "all" or cache_type == "gid": self._gid_cache.clear() logger.info("Cleared GID cache") @@ -265,7 +301,7 @@ def clear_cache(self, cache_type="all"): def get_cache_stats(self): """ Get statistics about cache usage. - + Returns: dict: Cache statistics """ @@ -278,15 +314,17 @@ def get_cache_stats(self): "admin2_cache_size": len(self._disk_admin2_cache), "gid_cache_size": len(self._disk_gid_cache), } - + # Calculate total cache size on disk total_size = 0 - for cache_dir in [self.cache_dir / d for d in ["country", "admin1", "admin2", "gid"]]: + for cache_dir in [ + self.cache_dir / d for d in ["country", "admin1", "admin2", "gid"] + ]: if cache_dir.exists(): for file in cache_dir.glob("*"): if file.is_file(): total_size += file.stat().st_size - + stats["total_disk_size_mb"] = round(total_size / (1024 * 1024), 2) else: stats = { @@ -300,17 +338,21 @@ def get_cache_stats(self): "admin2_cache_size": len(self._admin2_cache), "admin2_cache_maxsize": self._admin2_cache.maxsize, } - - if hasattr(self._country_cache, 'currsize'): - stats["country_cache_hit_rate"] = f"{self._country_cache.hits}/{self._country_cache.hits + self._country_cache.misses}" - stats["gid_cache_hit_rate"] = f"{self._gid_cache.hits}/{self._gid_cache.hits + self._gid_cache.misses}" - + + if hasattr(self._country_cache, "currsize"): + stats["country_cache_hit_rate"] = ( + f"{self._country_cache.hits}/{self._country_cache.hits + self._country_cache.misses}" + ) + stats["gid_cache_hit_rate"] = ( + f"{self._gid_cache.hits}/{self._gid_cache.hits + self._gid_cache.misses}" + ) + return stats def warm_cache(self, gid_list=None, iso_a3_list=None): """ Warm up the cache with frequently accessed data. - + Parameters: gid_list (list): List of GIDs to pre-cache iso_a3_list (list): List of ISO A3 codes to pre-cache @@ -321,20 +363,20 @@ def warm_cache(self, gid_list=None, iso_a3_list=None): self.find_country_for_gid(gid) self.find_admin1_for_gid(gid) self.find_admin2_for_gid(gid) - + if iso_a3_list: logger.info(f"Warming cache with {len(iso_a3_list)} countries") for iso_a3 in iso_a3_list: self.find_country_by_iso_a3(iso_a3) self.find_gids_for_country(iso_a3) - + logger.info("Cache warming complete") def __del__(self): if self._process_pool: self._process_pool.close() self._process_pool.join() - + def _init_process_pool(self): """Initialize the process pool if not already done""" if self._process_pool is None: @@ -343,29 +385,31 @@ def _init_process_pool(self): def batch_country_mapping_parallel(self, gid_list, batch_size=1000): """ Find countries for multiple PRIO-GRID cells using multiprocessing. - + Parameters: gid_list (list): List of PRIO-GRID cell IDs batch_size (int): Number of GIDs to process in each batch - + Returns: DataFrame: Results of the batch mapping """ self._init_process_pool() - + # Split into batches - batches = [gid_list[i:i + batch_size] for i in range(0, len(gid_list), batch_size)] - + batches = [ + gid_list[i : i + batch_size] for i in range(0, len(gid_list), batch_size) + ] + # Create partial function for mapping map_func = partial(self._process_gid_batch) - + # Process batches in parallel results = [] for batch_result in self._process_pool.imap_unordered(map_func, batches): results.extend(batch_result) - + return pd.DataFrame(results) - + def _process_gid_batch(self, gid_batch): """Process a batch of GIDs in a single process""" batch_results = [] @@ -521,7 +565,7 @@ def _find_gid_for_point_impl(point_geometry): break return result - + return _find_gid_for_point_impl(point_geometry) else: # Use in-memory cache @@ -561,7 +605,7 @@ def _find_gid_for_point_impl(point_geometry): # Update cache self._gid_cache[cache_key] = result return result - + def find_country_for_gid(self, gid): """ Find the country information for a PRIO-GRID cell using majority area-based method. @@ -580,12 +624,20 @@ def _find_country_for_gid_impl(gid): grid_centroid = grid_cell["centroid"].iloc[0] # Use spatial index to find potentially intersecting countries - if hasattr(self.countries_gdf, 'sindex'): - possible_countries_idx = list(self.countries_gdf.sindex.intersection(grid_geometry.bounds)) - candidate_countries = self.countries_gdf.iloc[possible_countries_idx] - candidate_countries = candidate_countries[candidate_countries.intersects(grid_geometry)] + if hasattr(self.countries_gdf, "sindex"): + possible_countries_idx = list( + self.countries_gdf.sindex.intersection(grid_geometry.bounds) + ) + candidate_countries = self.countries_gdf.iloc[ + possible_countries_idx + ] + candidate_countries = candidate_countries[ + candidate_countries.intersects(grid_geometry) + ] else: - candidate_countries = self.countries_gdf[self.countries_gdf.intersects(grid_geometry)] + candidate_countries = self.countries_gdf[ + self.countries_gdf.intersects(grid_geometry) + ] # Calculate area overlaps overlaps = [] @@ -596,11 +648,13 @@ def _find_country_for_gid_impl(gid): total_area = grid_geometry.area overlap_ratio = overlap_area / total_area - overlaps.append({ - "country_data": country, - "overlap_area": overlap_area, - "overlap_ratio": overlap_ratio, - }) + overlaps.append( + { + "country_data": country, + "overlap_area": overlap_area, + "overlap_ratio": overlap_ratio, + } + ) except Exception as e: logger.debug(f"Overlap calculation error for GID {gid}: {e}") continue @@ -620,7 +674,9 @@ def _find_country_for_gid_impl(gid): "gid": int(gid), "iso_a3": country["ISO_A3"], "country_name": country["NAME_EN"], - "overlap_ratio": float(overlaps[0]["overlap_ratio"]) if overlaps else 0.0, + "overlap_ratio": ( + float(overlaps[0]["overlap_ratio"]) if overlaps else 0.0 + ), "method": method_used, } @@ -630,13 +686,13 @@ def _find_country_for_gid_impl(gid): result[col] = country[col] return result - + return _find_country_for_gid_impl(gid) else: # Use in-memory cache # Generate consistent cache key cache_key = self._get_gid_cache_key(gid) - + # Check cache first if cache_key in self._country_cache: cached_value = self._country_cache[cache_key] @@ -655,12 +711,18 @@ def _find_country_for_gid_impl(gid): grid_centroid = grid_cell["centroid"].iloc[0] # Use spatial index to find potentially intersecting countries - if hasattr(self.countries_gdf, 'sindex'): - possible_countries_idx = list(self.countries_gdf.sindex.intersection(grid_geometry.bounds)) + if hasattr(self.countries_gdf, "sindex"): + possible_countries_idx = list( + self.countries_gdf.sindex.intersection(grid_geometry.bounds) + ) candidate_countries = self.countries_gdf.iloc[possible_countries_idx] - candidate_countries = candidate_countries[candidate_countries.intersects(grid_geometry)] + candidate_countries = candidate_countries[ + candidate_countries.intersects(grid_geometry) + ] else: - candidate_countries = self.countries_gdf[self.countries_gdf.intersects(grid_geometry)] + candidate_countries = self.countries_gdf[ + self.countries_gdf.intersects(grid_geometry) + ] # Calculate area overlaps overlaps = [] @@ -671,11 +733,13 @@ def _find_country_for_gid_impl(gid): total_area = grid_geometry.area overlap_ratio = overlap_area / total_area - overlaps.append({ - "country_data": country, - "overlap_area": overlap_area, - "overlap_ratio": overlap_ratio, - }) + overlaps.append( + { + "country_data": country, + "overlap_area": overlap_area, + "overlap_ratio": overlap_ratio, + } + ) except Exception as e: logger.debug(f"Overlap calculation error for GID {gid}: {e}") continue @@ -695,7 +759,9 @@ def _find_country_for_gid_impl(gid): "gid": int(gid), "iso_a3": country["ISO_A3"], "country_name": country["NAME_EN"], - "overlap_ratio": float(overlaps[0]["overlap_ratio"]) if overlaps else 0.0, + "overlap_ratio": ( + float(overlaps[0]["overlap_ratio"]) if overlaps else 0.0 + ), "method": method_used, } @@ -714,7 +780,7 @@ def batch_country_mapping(self, gid_list): """ results = [] cache_hits = 0 - + for gid in gid_list: # Check cache first cache_key = self._get_gid_cache_key(gid) @@ -726,14 +792,16 @@ def batch_country_mapping(self, gid_list): results.append(None) cache_hits += 1 continue - + # Not in cache, process this GID result = self.find_country_for_gid(gid) results.append(result) - + if cache_hits > 0: - logger.info(f"Cache hits: {cache_hits}/{len(gid_list)} ({cache_hits/len(gid_list)*100:.1f}%)") - + logger.info( + f"Cache hits: {cache_hits}/{len(gid_list)} ({cache_hits/len(gid_list)*100:.1f}%)" + ) + return pd.DataFrame([r for r in results if r is not None]) def find_gids_for_country(self, iso_a3): @@ -753,14 +821,22 @@ def _find_gids_for_country_impl(iso_a3): else: # Get the country geometry country_geometry = country.iloc[0]["geometry"] - + # OPTIMIZATION: Use spatial index to quickly find intersecting grid cells if self.priogrid_sindex: - possible_matches_index = list(self.priogrid_sindex.intersection(country_geometry.bounds)) - candidate_grid_cells = self.priogrid_gdf.iloc[possible_matches_index] - intersecting_cells = candidate_grid_cells[candidate_grid_cells.intersects(country_geometry)] + possible_matches_index = list( + self.priogrid_sindex.intersection(country_geometry.bounds) + ) + candidate_grid_cells = self.priogrid_gdf.iloc[ + possible_matches_index + ] + intersecting_cells = candidate_grid_cells[ + candidate_grid_cells.intersects(country_geometry) + ] else: - intersecting_cells = self.priogrid_gdf[self.priogrid_gdf.intersects(country_geometry)] + intersecting_cells = self.priogrid_gdf[ + self.priogrid_gdf.intersects(country_geometry) + ] if len(intersecting_cells) == 0: result = [] @@ -772,7 +848,7 @@ def _find_gids_for_country_impl(iso_a3): result = matching_gids return result - + return _find_gids_for_country_impl(iso_a3) else: # Use in-memory cache @@ -793,14 +869,22 @@ def _find_gids_for_country_impl(iso_a3): else: # Get the country geometry country_geometry = country.iloc[0]["geometry"] - + # OPTIMIZATION: Use spatial index to quickly find intersecting grid cells if self.priogrid_sindex: - possible_matches_index = list(self.priogrid_sindex.intersection(country_geometry.bounds)) - candidate_grid_cells = self.priogrid_gdf.iloc[possible_matches_index] - intersecting_cells = candidate_grid_cells[candidate_grid_cells.intersects(country_geometry)] + possible_matches_index = list( + self.priogrid_sindex.intersection(country_geometry.bounds) + ) + candidate_grid_cells = self.priogrid_gdf.iloc[ + possible_matches_index + ] + intersecting_cells = candidate_grid_cells[ + candidate_grid_cells.intersects(country_geometry) + ] else: - intersecting_cells = self.priogrid_gdf[self.priogrid_gdf.intersects(country_geometry)] + intersecting_cells = self.priogrid_gdf[ + self.priogrid_gdf.intersects(country_geometry) + ] if len(intersecting_cells) == 0: result = [] @@ -814,25 +898,25 @@ def _find_gids_for_country_impl(iso_a3): # Update cache self._gids_for_country_cache[cache_key] = result return result - + def _find_dominant_country_gids(self, intersecting_cells, country_geometry, iso_a3): """Find grid cells where target country has dominant area using optimized approach.""" matching_gids = [] - + for _, grid_cell in intersecting_cells.iterrows(): grid_geometry = grid_cell["geometry"] grid_gid = grid_cell["gid"] - + # OPTIMIZATION 1: Early check for complete containment if country_geometry.contains(grid_geometry): matching_gids.append(int(grid_gid)) continue - + # Calculate overlap with the target country try: intersection = country_geometry.intersection(grid_geometry) overlap_ratio = intersection.area / grid_geometry.area - + # If overlap is significant, include it if overlap_ratio > 0.5: matching_gids.append(int(grid_gid)) @@ -840,7 +924,7 @@ def _find_dominant_country_gids(self, intersecting_cells, country_geometry, iso_ # Handle potential geometry errors gracefully logger.debug(f"Geometry error for GID {grid_gid}: {e}") continue - + return matching_gids def get_all_iso_a3_codes(self): @@ -878,10 +962,12 @@ def visualize_grid_and_country(self, gid): # Find which country contains the grid cell based on majority area country_info = self.find_country_for_gid(gid) containing_country = None - + if country_info: iso_a3 = country_info["iso_a3"] - containing_country = self.countries_gdf[self.countries_gdf["ISO_A3"] == iso_a3].iloc[0] + containing_country = self.countries_gdf[ + self.countries_gdf["ISO_A3"] == iso_a3 + ].iloc[0] # Create plot fig, ax = plt.subplots(1, 1, figsize=(10, 8)) @@ -918,10 +1004,10 @@ def visualize_grid_and_country(self, gid): ) method = country_info.get("method", "unknown") if country_info else "unknown" overlap_ratio = country_info.get("overlap_ratio", 0) if country_info else 0 - + ax.set_title( - f'PRIO-GRID Cell {gid} in {country_name}\n' - f'Method: {method} | Overlap Ratio: {overlap_ratio:.3f}' + f"PRIO-GRID Cell {gid} in {country_name}\n" + f"Method: {method} | Overlap Ratio: {overlap_ratio:.3f}" ) ax.legend() @@ -931,18 +1017,18 @@ def visualize_grid_and_country(self, gid): def calculate_capital_distance(self, grid_lat, grid_lon, capital_coords): """ Calculate distance from grid cell to capital. - + Parameters: grid_lat: Latitude of grid centroid grid_lon: Longitude of grid centroid capital_coords: Tuple of (capital_lon, capital_lat) or None - + Returns: float or None: Distance in kilometers or None if calculation fails """ if capital_coords is None or None in capital_coords: return None - + try: capital_lon, capital_lat = capital_coords return cached_haversine(grid_lat, grid_lon, capital_lat, capital_lon) @@ -964,7 +1050,7 @@ def get_gid_from_point(self, point: ShapelyPoint) -> Optional[int]: if grid_cell["geometry"].contains(point): return grid_cell["gid"] return None - + def get_point_from_gid(self, gid: int) -> Optional["Point"]: """ Get the centroid point of a PRIO-GRID cell. @@ -981,7 +1067,7 @@ def get_point_from_gid(self, gid: int) -> Optional["Point"]: return None centroid = grid_cell["centroid"].iloc[0] return Point(lat=centroid.y, lon=centroid.x) - + def _load_admin_data(self, admin_path, admin_level): """ Load and preprocess admin1 or admin2 data from the provided path. @@ -1025,7 +1111,7 @@ def _validate_admin_data(self, admin_gdf, admin_level): required_columns = ["gaul1_code", "gaul1_name", "iso3_code", "geometry"] else: # admin2 required_columns = ["gaul2_code", "gaul2_name", "iso3_code", "geometry"] - + missing_columns = [ col for col in required_columns if col not in admin_gdf.columns ] @@ -1059,9 +1145,9 @@ def _find_admin1_for_gid_impl(gid): country_info = self.find_country_for_gid(gid) if not country_info: return None - + iso_a3 = country_info["iso_a3"] - + # Get the PRIO-GRID cell grid_cell = self.priogrid_gdf[self.priogrid_gdf["gid"] == gid] if len(grid_cell) == 0: @@ -1071,19 +1157,27 @@ def _find_admin1_for_gid_impl(gid): # Filter admin1 regions to only those in the same country country_admin1 = self.admin1_gdf[self.admin1_gdf["iso3_code"] == iso_a3] - + if len(country_admin1) == 0: return None # Use spatial index to find potentially intersecting admin1 regions if self.admin1_sindex: - possible_admin1_idx = list(self.admin1_sindex.intersection(grid_geometry.bounds)) + possible_admin1_idx = list( + self.admin1_sindex.intersection(grid_geometry.bounds) + ) candidate_admin1 = self.admin1_gdf.iloc[possible_admin1_idx] - candidate_admin1 = candidate_admin1[candidate_admin1.intersects(grid_geometry)] + candidate_admin1 = candidate_admin1[ + candidate_admin1.intersects(grid_geometry) + ] # Filter to only those in the same country - candidate_admin1 = candidate_admin1[candidate_admin1["iso3_code"] == iso_a3] + candidate_admin1 = candidate_admin1[ + candidate_admin1["iso3_code"] == iso_a3 + ] else: - candidate_admin1 = country_admin1[country_admin1.intersects(grid_geometry)] + candidate_admin1 = country_admin1[ + country_admin1.intersects(grid_geometry) + ] if len(candidate_admin1) == 0: return None @@ -1097,28 +1191,34 @@ def _find_admin1_for_gid_impl(gid): total_area = grid_geometry.area overlap_ratio = overlap_area / total_area - overlaps.append({ - "admin1_data": admin1, - "overlap_area": overlap_area, - "overlap_ratio": overlap_ratio, - }) + overlaps.append( + { + "admin1_data": admin1, + "overlap_area": overlap_area, + "overlap_ratio": overlap_ratio, + } + ) except Exception as e: logger.debug(f"Overlap calculation error for GID {gid}: {e}") continue # Sort by overlap ratio (descending) overlaps.sort(key=lambda x: x["overlap_ratio"], reverse=True) - + if not overlaps: return None - + admin1 = overlaps[0]["admin1_data"] method_used = "largest overlap" # Prepare result result = { "gid": int(gid), - "gaul1_code": int(admin1["gaul1_code"]) if admin1["gaul1_code"] is not None else None, + "gaul1_code": ( + int(admin1["gaul1_code"]) + if admin1["gaul1_code"] is not None + else None + ), "gaul1_name": admin1["gaul1_name"], "iso3_code": admin1["iso3_code"], "method": method_used, @@ -1126,7 +1226,11 @@ def _find_admin1_for_gid_impl(gid): # Add additional fields if available if "gaul0_code" in admin1: - result["gaul0_code"] = int(admin1["gaul0_code"]) if admin1["gaul0_code"] is not None else None + result["gaul0_code"] = ( + int(admin1["gaul0_code"]) + if admin1["gaul0_code"] is not None + else None + ) if "gaul0_name" in admin1: result["gaul0_name"] = admin1["gaul0_name"] if "continent" in admin1: @@ -1135,13 +1239,13 @@ def _find_admin1_for_gid_impl(gid): result["disp_en"] = admin1["disp_en"] return result - + return _find_admin1_for_gid_impl(gid) else: # Use in-memory cache # Generate consistent cache key cache_key = self._get_admin_cache_key(gid, "admin1") - + # Check cache first if cache_key in self._admin1_cache: cached_value = self._admin1_cache[cache_key] @@ -1156,9 +1260,9 @@ def _find_admin1_for_gid_impl(gid): if not country_info: self._admin1_cache[cache_key] = None return None - + iso_a3 = country_info["iso_a3"] - + # Get the PRIO-GRID cell grid_cell = self.priogrid_gdf[self.priogrid_gdf["gid"] == gid] if len(grid_cell) == 0: @@ -1169,20 +1273,28 @@ def _find_admin1_for_gid_impl(gid): # Filter admin1 regions to only those in the same country country_admin1 = self.admin1_gdf[self.admin1_gdf["iso3_code"] == iso_a3] - + if len(country_admin1) == 0: self._admin1_cache[cache_key] = None return None # Use spatial index to find potentially intersecting admin1 regions if self.admin1_sindex: - possible_admin1_idx = list(self.admin1_sindex.intersection(grid_geometry.bounds)) + possible_admin1_idx = list( + self.admin1_sindex.intersection(grid_geometry.bounds) + ) candidate_admin1 = self.admin1_gdf.iloc[possible_admin1_idx] - candidate_admin1 = candidate_admin1[candidate_admin1.intersects(grid_geometry)] + candidate_admin1 = candidate_admin1[ + candidate_admin1.intersects(grid_geometry) + ] # Filter to only those in the same country - candidate_admin1 = candidate_admin1[candidate_admin1["iso3_code"] == iso_a3] + candidate_admin1 = candidate_admin1[ + candidate_admin1["iso3_code"] == iso_a3 + ] else: - candidate_admin1 = country_admin1[country_admin1.intersects(grid_geometry)] + candidate_admin1 = country_admin1[ + country_admin1.intersects(grid_geometry) + ] if len(candidate_admin1) == 0: self._admin1_cache[cache_key] = None @@ -1197,29 +1309,35 @@ def _find_admin1_for_gid_impl(gid): total_area = grid_geometry.area overlap_ratio = overlap_area / total_area - overlaps.append({ - "admin1_data": admin1, - "overlap_area": overlap_area, - "overlap_ratio": overlap_ratio, - }) + overlaps.append( + { + "admin1_data": admin1, + "overlap_area": overlap_area, + "overlap_ratio": overlap_ratio, + } + ) except Exception as e: logger.debug(f"Overlap calculation error for GID {gid}: {e}") continue # Sort by overlap ratio (descending) overlaps.sort(key=lambda x: x["overlap_ratio"], reverse=True) - + if not overlaps: self._admin1_cache[cache_key] = None return None - + admin1 = overlaps[0]["admin1_data"] method_used = "largest overlap" # Prepare result result = { "gid": int(gid), - "gaul1_code": int(admin1["gaul1_code"]) if admin1["gaul1_code"] is not None else None, + "gaul1_code": ( + int(admin1["gaul1_code"]) + if admin1["gaul1_code"] is not None + else None + ), "gaul1_name": admin1["gaul1_name"], "iso3_code": admin1["iso3_code"], "method": method_used, @@ -1227,7 +1345,11 @@ def _find_admin1_for_gid_impl(gid): # Add additional fields if available if "gaul0_code" in admin1: - result["gaul0_code"] = int(admin1["gaul0_code"]) if admin1["gaul0_code"] is not None else None + result["gaul0_code"] = ( + int(admin1["gaul0_code"]) + if admin1["gaul0_code"] is not None + else None + ) if "gaul0_name" in admin1: result["gaul0_name"] = admin1["gaul0_name"] if "continent" in admin1: @@ -1257,9 +1379,9 @@ def _find_admin2_for_gid_impl(gid): country_info = self.find_country_for_gid(gid) if not country_info: return None - + iso_a3 = country_info["iso_a3"] - + # Get the PRIO-GRID cell grid_cell = self.priogrid_gdf[self.priogrid_gdf["gid"] == gid] if len(grid_cell) == 0: @@ -1269,19 +1391,27 @@ def _find_admin2_for_gid_impl(gid): # Filter admin2 regions to only those in the same country country_admin2 = self.admin2_gdf[self.admin2_gdf["iso3_code"] == iso_a3] - + if len(country_admin2) == 0: return None # Use spatial index to find potentially intersecting admin2 regions if self.admin2_sindex: - possible_admin2_idx = list(self.admin2_sindex.intersection(grid_geometry.bounds)) + possible_admin2_idx = list( + self.admin2_sindex.intersection(grid_geometry.bounds) + ) candidate_admin2 = self.admin2_gdf.iloc[possible_admin2_idx] - candidate_admin2 = candidate_admin2[candidate_admin2.intersects(grid_geometry)] + candidate_admin2 = candidate_admin2[ + candidate_admin2.intersects(grid_geometry) + ] # Filter to only those in the same country - candidate_admin2 = candidate_admin2[candidate_admin2["iso3_code"] == iso_a3] + candidate_admin2 = candidate_admin2[ + candidate_admin2["iso3_code"] == iso_a3 + ] else: - candidate_admin2 = country_admin2[country_admin2.intersects(grid_geometry)] + candidate_admin2 = country_admin2[ + country_admin2.intersects(grid_geometry) + ] if len(candidate_admin2) == 0: return None @@ -1295,28 +1425,34 @@ def _find_admin2_for_gid_impl(gid): total_area = grid_geometry.area overlap_ratio = overlap_area / total_area - overlaps.append({ - "admin2_data": admin2, - "overlap_area": overlap_area, - "overlap_ratio": overlap_ratio, - }) + overlaps.append( + { + "admin2_data": admin2, + "overlap_area": overlap_area, + "overlap_ratio": overlap_ratio, + } + ) except Exception as e: logger.debug(f"Overlap calculation error for GID {gid}: {e}") continue # Sort by overlap ratio (descending) overlaps.sort(key=lambda x: x["overlap_ratio"], reverse=True) - + if not overlaps: return None - + admin2 = overlaps[0]["admin2_data"] method_used = "largest overlap" # Prepare result result = { "gid": int(gid), - "gaul2_code": int(admin2["gaul2_code"]) if admin2["gaul2_code"] is not None else None, + "gaul2_code": ( + int(admin2["gaul2_code"]) + if admin2["gaul2_code"] is not None + else None + ), "gaul2_name": admin2["gaul2_name"], "iso3_code": admin2["iso3_code"], "method": method_used, @@ -1324,11 +1460,19 @@ def _find_admin2_for_gid_impl(gid): # Add additional fields if available if "gaul0_code" in admin2: - result["gaul0_code"] = int(admin2["gaul0_code"]) if admin2["gaul0_code"] is not None else None + result["gaul0_code"] = ( + int(admin2["gaul0_code"]) + if admin2["gaul0_code"] is not None + else None + ) if "gaul0_name" in admin2: result["gaul0_name"] = admin2["gaul0_name"] if "gaul1_code" in admin2: - result["gaul1_code"] = int(admin2["gaul1_code"]) if admin2["gaul1_code"] is not None else None + result["gaul1_code"] = ( + int(admin2["gaul1_code"]) + if admin2["gaul1_code"] is not None + else None + ) if "gaul1_name" in admin2: result["gaul1_name"] = admin2["gaul1_name"] if "continent" in admin2: @@ -1337,13 +1481,13 @@ def _find_admin2_for_gid_impl(gid): result["disp_en"] = admin2["disp_en"] return result - + return _find_admin2_for_gid_impl(gid) else: # Use in-memory cache # Generate consistent cache key cache_key = self._get_admin_cache_key(gid, "admin2") - + # Check cache first if cache_key in self._admin2_cache: cached_value = self._admin2_cache[cache_key] @@ -1358,9 +1502,9 @@ def _find_admin2_for_gid_impl(gid): if not country_info: self._admin2_cache[cache_key] = None return None - + iso_a3 = country_info["iso_a3"] - + # Get the PRIO-GRID cell grid_cell = self.priogrid_gdf[self.priogrid_gdf["gid"] == gid] if len(grid_cell) == 0: @@ -1371,20 +1515,28 @@ def _find_admin2_for_gid_impl(gid): # Filter admin2 regions to only those in the same country country_admin2 = self.admin2_gdf[self.admin2_gdf["iso3_code"] == iso_a3] - + if len(country_admin2) == 0: self._admin2_cache[cache_key] = None return None # Use spatial index to find potentially intersecting admin2 regions if self.admin2_sindex: - possible_admin2_idx = list(self.admin2_sindex.intersection(grid_geometry.bounds)) + possible_admin2_idx = list( + self.admin2_sindex.intersection(grid_geometry.bounds) + ) candidate_admin2 = self.admin2_gdf.iloc[possible_admin2_idx] - candidate_admin2 = candidate_admin2[candidate_admin2.intersects(grid_geometry)] + candidate_admin2 = candidate_admin2[ + candidate_admin2.intersects(grid_geometry) + ] # Filter to only those in the same country - candidate_admin2 = candidate_admin2[candidate_admin2["iso3_code"] == iso_a3] + candidate_admin2 = candidate_admin2[ + candidate_admin2["iso3_code"] == iso_a3 + ] else: - candidate_admin2 = country_admin2[country_admin2.intersects(grid_geometry)] + candidate_admin2 = country_admin2[ + country_admin2.intersects(grid_geometry) + ] if len(candidate_admin2) == 0: self._admin2_cache[cache_key] = None @@ -1399,29 +1551,35 @@ def _find_admin2_for_gid_impl(gid): total_area = grid_geometry.area overlap_ratio = overlap_area / total_area - overlaps.append({ - "admin2_data": admin2, - "overlap_area": overlap_area, - "overlap_ratio": overlap_ratio, - }) + overlaps.append( + { + "admin2_data": admin2, + "overlap_area": overlap_area, + "overlap_ratio": overlap_ratio, + } + ) except Exception as e: logger.debug(f"Overlap calculation error for GID {gid}: {e}") continue # Sort by overlap ratio (descending) overlaps.sort(key=lambda x: x["overlap_ratio"], reverse=True) - + if not overlaps: self._admin2_cache[cache_key] = None return None - + admin2 = overlaps[0]["admin2_data"] method_used = "largest overlap" # Prepare result result = { "gid": int(gid), - "gaul2_code": int(admin2["gaul2_code"]) if admin2["gaul2_code"] is not None else None, + "gaul2_code": ( + int(admin2["gaul2_code"]) + if admin2["gaul2_code"] is not None + else None + ), "gaul2_name": admin2["gaul2_name"], "iso3_code": admin2["iso3_code"], "method": method_used, @@ -1429,11 +1587,19 @@ def _find_admin2_for_gid_impl(gid): # Add additional fields if available if "gaul0_code" in admin2: - result["gaul0_code"] = int(admin2["gaul0_code"]) if admin2["gaul0_code"] is not None else None + result["gaul0_code"] = ( + int(admin2["gaul0_code"]) + if admin2["gaul0_code"] is not None + else None + ) if "gaul0_name" in admin2: result["gaul0_name"] = admin2["gaul0_name"] if "gaul1_code" in admin2: - result["gaul1_code"] = int(admin2["gaul1_code"]) if admin2["gaul1_code"] is not None else None + result["gaul1_code"] = ( + int(admin2["gaul1_code"]) + if admin2["gaul1_code"] is not None + else None + ) if "gaul1_name" in admin2: result["gaul1_name"] = admin2["gaul1_name"] if "continent" in admin2: @@ -1456,22 +1622,22 @@ def find_all_admin_for_gid(self, gid): dict: Combined information for country, admin1, and admin2 """ result = {"gid": int(gid)} - + # Find country information country_info = self.find_country_for_gid(gid) if country_info: result["country"] = country_info - + # Find admin1 information admin1_info = self.find_admin1_for_gid(gid) if admin1_info: result["admin1"] = admin1_info - + # Find admin2 information admin2_info = self.find_admin2_for_gid(gid) if admin2_info: result["admin2"] = admin2_info - + return result def batch_admin_mapping(self, gid_list, admin_level="admin1"): @@ -1487,12 +1653,12 @@ def batch_admin_mapping(self, gid_list, admin_level="admin1"): """ if admin_level not in ["admin1", "admin2"]: raise ValueError("admin_level must be either 'admin1' or 'admin2'") - + if admin_level == "admin1" and self.admin1_gdf is None: raise ValueError("Admin1 data not loaded") if admin_level == "admin2" and self.admin2_gdf is None: raise ValueError("Admin2 data not loaded") - + # Select the appropriate find function if admin_level == "admin1": find_func = self.find_admin1_for_gid @@ -1500,10 +1666,10 @@ def batch_admin_mapping(self, gid_list, admin_level="admin1"): else: # admin2 find_func = self.find_admin2_for_gid cache = self._admin2_cache - + results = [] cache_hits = 0 - + for gid in gid_list: # Check cache first cache_key = self._get_admin_cache_key(gid, admin_level) @@ -1515,17 +1681,21 @@ def batch_admin_mapping(self, gid_list, admin_level="admin1"): results.append(None) cache_hits += 1 continue - + # Not in cache, process this GID result = find_func(gid) results.append(result) - + if cache_hits > 0: - logger.info(f"Cache hits: {cache_hits}/{len(gid_list)} ({cache_hits/len(gid_list)*100:.1f}%)") - + logger.info( + f"Cache hits: {cache_hits}/{len(gid_list)} ({cache_hits/len(gid_list)*100:.1f}%)" + ) + return pd.DataFrame([r for r in results if r is not None]) - def batch_admin_mapping_parallel(self, gid_list, admin_level="admin1", batch_size=1000): + def batch_admin_mapping_parallel( + self, gid_list, admin_level="admin1", batch_size=1000 + ): """ Batch mapping for admin1 or admin2 using multiprocessing. @@ -1539,33 +1709,35 @@ def batch_admin_mapping_parallel(self, gid_list, admin_level="admin1", batch_siz """ if admin_level not in ["admin1", "admin2"]: raise ValueError("admin_level must be either 'admin1' or 'admin2'") - + if admin_level == "admin1" and self.admin1_gdf is None: raise ValueError("Admin1 data not loaded") if admin_level == "admin2" and self.admin2_gdf is None: raise ValueError("Admin2 data not loaded") - + self._init_process_pool() - + # Split into batches - batches = [gid_list[i:i + batch_size] for i in range(0, len(gid_list), batch_size)] - + batches = [ + gid_list[i : i + batch_size] for i in range(0, len(gid_list), batch_size) + ] + # Select the appropriate find function if admin_level == "admin1": find_func = self.find_admin1_for_gid else: # admin2 find_func = self.find_admin2_for_gid - + # Create partial function for mapping map_func = partial(self._process_gid_batch_admin, find_func=find_func) - + # Process batches in parallel results = [] for batch_result in self._process_pool.imap_unordered(map_func, batches): results.extend(batch_result) - + return pd.DataFrame(results) - + def _process_gid_batch_admin(self, gid_batch, find_func): """Process a batch of GIDs in a single process for admin mapping""" batch_results = [] @@ -1588,7 +1760,7 @@ def find_all_admin_for_point(self, point_geometry): gid = self.find_gid_for_point(point_geometry) if gid is None: return None - + # Use the existing method to find all admin information return self.find_all_admin_for_gid(gid) @@ -1604,32 +1776,34 @@ def extend_find_country_for_gid(self, gid): """ # Get country information using the existing method result = self.find_country_for_gid(gid) - + if result is None: return None - + # Add admin1 information admin1_info = self.find_admin1_for_gid(gid) if admin1_info: result["admin1"] = admin1_info - + # Add admin2 information admin2_info = self.find_admin2_for_gid(gid) if admin2_info: result["admin2"] = admin2_info - + return result - - def visualize_grid_and_admin(self, gid, admin_level="admin1", show_all_admins=False): + + def visualize_grid_and_admin( + self, gid, admin_level="admin1", show_all_admins=False + ): """ Visualize a PRIO-GRID cell and its corresponding admin1 or admin2 boundary. - + Parameters: gid (int): PRIO-GRID cell ID admin_level (str): Either "admin1" or "admin2" - show_all_admins (bool): If True, shows all admin regions in the area, + show_all_admins (bool): If True, shows all admin regions in the area, otherwise only shows the one containing the grid cell - + Raises: ImportError: If matplotlib is not available ValueError: If the PRIO-GRID cell ID is not found or admin data is not loaded @@ -1640,24 +1814,24 @@ def visualize_grid_and_admin(self, gid, admin_level="admin1", show_all_admins=Fa except ImportError: logger.error("Matplotlib is required for visualization") return - + if admin_level not in ["admin1", "admin2"]: raise ValueError("admin_level must be either 'admin1' or 'admin2'") - + # Check if admin data is loaded if admin_level == "admin1" and self.admin1_gdf is None: raise ValueError("Admin1 data not loaded. Cannot visualize admin1.") if admin_level == "admin2" and self.admin2_gdf is None: raise ValueError("Admin2 data not loaded. Cannot visualize admin2.") - + # Get the grid cell grid_cell = self.priogrid_gdf[self.priogrid_gdf["gid"] == gid] if len(grid_cell) == 0: raise ValueError(f"PRIO-GRID cell with ID {gid} not found") - + grid_geometry = grid_cell["geometry"].iloc[0] grid_centroid = grid_cell["centroid"].iloc[0] - + # Select appropriate admin data if admin_level == "admin1": admin_gdf = self.admin1_gdf @@ -1669,42 +1843,64 @@ def visualize_grid_and_admin(self, gid, admin_level="admin1", show_all_admins=Fa admin_code_col = "gaul2_code" admin_name_col = "gaul2_name" find_func = self.find_admin2_for_gid - + # Find the admin region containing the grid cell admin_info = find_func(gid) - + # Create plot fig, ax = plt.subplots(1, 1, figsize=(12, 10)) - + if show_all_admins: # Show all admin regions that intersect with the grid cell's bounds if admin_level == "admin1" and self.admin1_sindex: - possible_admins_idx = list(self.admin1_sindex.intersection(grid_geometry.bounds)) + possible_admins_idx = list( + self.admin1_sindex.intersection(grid_geometry.bounds) + ) nearby_admins = self.admin1_gdf.iloc[possible_admins_idx] elif admin_level == "admin2" and self.admin2_sindex: - possible_admins_idx = list(self.admin2_sindex.intersection(grid_geometry.bounds)) + possible_admins_idx = list( + self.admin2_sindex.intersection(grid_geometry.bounds) + ) nearby_admins = self.admin2_gdf.iloc[possible_admins_idx] else: nearby_admins = admin_gdf[admin_gdf.intersects(grid_geometry.buffer(1))] - + # Plot all nearby admin regions - nearby_admins.plot(ax=ax, color="lightgray", edgecolor="white", alpha=0.5, linewidth=0.5) - + nearby_admins.plot( + ax=ax, color="lightgray", edgecolor="white", alpha=0.5, linewidth=0.5 + ) + # Highlight the containing admin region if found if admin_info: - containing_admin = admin_gdf[admin_gdf[admin_code_col] == admin_info[admin_code_col]] + containing_admin = admin_gdf[ + admin_gdf[admin_code_col] == admin_info[admin_code_col] + ] if len(containing_admin) > 0: - containing_admin.plot(ax=ax, color="lightblue", edgecolor="blue", alpha=0.7, linewidth=1.5) + containing_admin.plot( + ax=ax, + color="lightblue", + edgecolor="blue", + alpha=0.7, + linewidth=1.5, + ) else: # Only show the containing admin region if admin_info: - containing_admin = admin_gdf[admin_gdf[admin_code_col] == admin_info[admin_code_col]] + containing_admin = admin_gdf[ + admin_gdf[admin_code_col] == admin_info[admin_code_col] + ] if len(containing_admin) > 0: - containing_admin.plot(ax=ax, color="lightblue", edgecolor="blue", alpha=0.7, linewidth=1.5) - + containing_admin.plot( + ax=ax, + color="lightblue", + edgecolor="blue", + alpha=0.7, + linewidth=1.5, + ) + # Plot the grid cell grid_cell.plot(ax=ax, color="red", alpha=0.8, edgecolor="darkred", linewidth=2) - + # Add centroid marker ax.scatter( grid_centroid.x, @@ -1715,66 +1911,78 @@ def visualize_grid_and_admin(self, gid, admin_level="admin1", show_all_admins=Fa edgecolor="black", linewidth=1, label="Grid Centroid", - zorder=5 + zorder=5, ) - + # Set title and legend if admin_info: admin_name = admin_info[admin_name_col] admin_code = admin_info[admin_code_col] method = admin_info.get("method", "unknown") - - title = f'PRIO-GRID Cell {gid} in {admin_level.upper()} {admin_name} (Code: {admin_code})\n' - title += f'Method: {method}' + + title = f"PRIO-GRID Cell {gid} in {admin_level.upper()} {admin_name} (Code: {admin_code})\n" + title += f"Method: {method}" else: - title = f'PRIO-GRID Cell {gid} - No {admin_level.upper()} region found' - - ax.set_title(title, fontsize=14, fontweight='bold') - + title = f"PRIO-GRID Cell {gid} - No {admin_level.upper()} region found" + + ax.set_title(title, fontsize=14, fontweight="bold") + # Create legend legend_elements = [ - mpatches.Patch(color='red', alpha=0.8, label='PRIO-GRID Cell'), - mpatches.Patch(color='lightblue', alpha=0.7, label=f'Containing {admin_level.upper()}'), - mpatches.Patch(color='lightgray', alpha=0.5, label='Other Admin Regions' if show_all_admins else None), + mpatches.Patch(color="red", alpha=0.8, label="PRIO-GRID Cell"), + mpatches.Patch( + color="lightblue", alpha=0.7, label=f"Containing {admin_level.upper()}" + ), + mpatches.Patch( + color="lightgray", + alpha=0.5, + label="Other Admin Regions" if show_all_admins else None, + ), ] - + # Filter out None values legend_elements = [e for e in legend_elements if e.get_label() is not None] - - ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 1)) - + + ax.legend(handles=legend_elements, loc="upper right", bbox_to_anchor=(1, 1)) + # Add grid coordinates as text ax.text( - grid_centroid.x, - grid_centroid.y - 0.1, + grid_centroid.x, + grid_centroid.y - 0.1, f'GID: {gid}\n({grid_cell["xcoord"].iloc[0]:.1f}, {grid_cell["ycoord"].iloc[0]:.1f})', - ha='center', - va='top', + ha="center", + va="top", fontsize=9, - bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8) + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), ) - + # Set aspect ratio and remove unnecessary axes - ax.set_aspect('equal') - ax.set_xlabel('Longitude') - ax.set_ylabel('Latitude') + ax.set_aspect("equal") + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") ax.grid(True, alpha=0.3) - + plt.tight_layout() plt.show() - def visualize_admin_regions_for_gids(self, gid_list, admin_level="admin1", show_grid_cells=True, - color_by_country=True, figsize=(15, 10)): + def visualize_admin_regions_for_gids( + self, + gid_list, + admin_level="admin1", + show_grid_cells=True, + color_by_country=True, + figsize=(15, 10), + ): """ Visualize admin1 or admin2 regions for multiple PRIO-GRID cells. - + Parameters: gid_list (list): List of PRIO-GRID cell IDs admin_level (str): Either "admin1" or "admin2" show_grid_cells (bool): Whether to show the PRIO-GRID cells on the map color_by_country (bool): If True, colors admin regions by country figsize (tuple): Figure size (width, height) - + Raises: ImportError: If matplotlib is not available ValueError: If admin data is not loaded @@ -1787,16 +1995,16 @@ def visualize_admin_regions_for_gids(self, gid_list, admin_level="admin1", show_ except ImportError: logger.error("Matplotlib is required for visualization") return - + if admin_level not in ["admin1", "admin2"]: raise ValueError("admin_level must be either 'admin1' or 'admin2'") - + # Check if admin data is loaded if admin_level == "admin1" and self.admin1_gdf is None: raise ValueError("Admin1 data not loaded. Cannot visualize admin1.") if admin_level == "admin2" and self.admin2_gdf is None: raise ValueError("Admin2 data not loaded. Cannot visualize admin2.") - + # Get admin information for all GIDs if admin_level == "admin1": admin_results = self.batch_admin_mapping(gid_list, admin_level="admin1") @@ -1808,28 +2016,31 @@ def visualize_admin_regions_for_gids(self, gid_list, admin_level="admin1", show_ admin_gdf = self.admin2_gdf admin_code_col = "gaul2_code" admin_name_col = "gaul2_name" - + if admin_results.empty: logger.warning(f"No {admin_level} regions found for the provided GIDs") return - + # Get unique admin regions unique_admin_codes = admin_results[admin_code_col].unique() admin_regions = admin_gdf[admin_gdf[admin_code_col].isin(unique_admin_codes)] - + # Get grid cells grid_cells = self.priogrid_gdf[self.priogrid_gdf["gid"].isin(gid_list)] - + # Create plot fig, ax = plt.subplots(1, 1, figsize=figsize) - + # Plot all admin regions in the area if color_by_country: # Create a colormap for countries unique_countries = admin_results["iso3_code"].unique() country_colors = plt.cm.tab20(np.linspace(0, 1, len(unique_countries))) - country_color_map = {country: color for country, color in zip(unique_countries, country_colors)} - + country_color_map = { + country: color + for country, color in zip(unique_countries, country_colors) + } + # Plot with color based on country for _, admin in admin_regions.iterrows(): admin_code = admin[admin_code_col] @@ -1838,9 +2049,9 @@ def visualize_admin_regions_for_gids(self, gid_list, admin_level="admin1", show_ country = admin_row["iso3_code"].iloc[0] color = country_color_map[country] admin_gdf[admin_gdf[admin_code_col] == admin_code].plot( - ax=ax, color=[color], edgecolor='black', alpha=0.7, linewidth=1 + ax=ax, color=[color], edgecolor="black", alpha=0.7, linewidth=1 ) - + # Create legend for countries legend_elements = [ mpatches.Patch(color=country_color_map[country], label=country) @@ -1848,85 +2059,96 @@ def visualize_admin_regions_for_gids(self, gid_list, admin_level="admin1", show_ ] else: # Plot with uniform color - admin_regions.plot(ax=ax, color="lightblue", edgecolor="black", alpha=0.7, linewidth=1) + admin_regions.plot( + ax=ax, color="lightblue", edgecolor="black", alpha=0.7, linewidth=1 + ) legend_elements = [ - mpatches.Patch(color='lightblue', alpha=0.7, label=f'{admin_level.upper()} Regions') + mpatches.Patch( + color="lightblue", alpha=0.7, label=f"{admin_level.upper()} Regions" + ) ] - + # Plot grid cells if requested if show_grid_cells: - grid_cells.plot(ax=ax, color="red", alpha=0.8, edgecolor="darkred", linewidth=1.5) - + grid_cells.plot( + ax=ax, color="red", alpha=0.8, edgecolor="darkred", linewidth=1.5 + ) + # Add labels for grid cells for _, grid_cell in grid_cells.iterrows(): centroid = grid_cell["centroid"] ax.text( - centroid.x, - centroid.y, + centroid.x, + centroid.y, str(grid_cell["gid"]), - ha='center', - va='center', + ha="center", + va="center", fontsize=8, - color='white', - fontweight='bold', - bbox=dict(boxstyle="circle,pad=0.2", facecolor="red", alpha=0.7) + color="white", + fontweight="bold", + bbox=dict(boxstyle="circle,pad=0.2", facecolor="red", alpha=0.7), ) - + # Add labels for admin regions for _, admin in admin_regions.iterrows(): centroid = admin["geometry"].centroid admin_name = admin[admin_name_col] admin_code = admin[admin_code_col] - + # Truncate long names if len(admin_name) > 20: admin_name = admin_name[:17] + "..." - + ax.text( - centroid.x, - centroid.y, - f'{admin_name}\n({admin_code})', - ha='center', - va='center', + centroid.x, + centroid.y, + f"{admin_name}\n({admin_code})", + ha="center", + va="center", fontsize=9, - bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8) + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), ) - + # Set title ax.set_title( - f'{admin_level.upper()} Regions for {len(gid_list)} PRIO-GRID Cells\n' - f'Found {len(unique_admin_codes)} unique {admin_level.upper()} regions', + f"{admin_level.upper()} Regions for {len(gid_list)} PRIO-GRID Cells\n" + f"Found {len(unique_admin_codes)} unique {admin_level.upper()} regions", fontsize=14, - fontweight='bold' + fontweight="bold", ) - + # Add grid cells to legend if shown if show_grid_cells: - legend_elements.insert(0, mpatches.Patch(color='red', alpha=0.8, label='PRIO-GRID Cells')) - - ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 1)) - + legend_elements.insert( + 0, mpatches.Patch(color="red", alpha=0.8, label="PRIO-GRID Cells") + ) + + ax.legend(handles=legend_elements, loc="upper right", bbox_to_anchor=(1, 1)) + # Set aspect ratio and labels - ax.set_aspect('equal') - ax.set_xlabel('Longitude') - ax.set_ylabel('Latitude') + ax.set_aspect("equal") + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") ax.grid(True, alpha=0.3) - + # Add summary statistics as text stats_text = f"Total GIDs: {len(gid_list)}\n" - stats_text += f"Unique {admin_level.upper()} regions: {len(unique_admin_codes)}\n" + stats_text += ( + f"Unique {admin_level.upper()} regions: {len(unique_admin_codes)}\n" + ) if color_by_country: stats_text += f"Unique countries: {len(unique_countries)}" - + ax.text( - 0.02, 0.98, + 0.02, + 0.98, stats_text, transform=ax.transAxes, fontsize=10, - verticalalignment='top', - bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.8) + verticalalignment="top", + bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.8), ) - + plt.tight_layout() plt.show() @@ -1943,31 +2165,35 @@ def find_country_by_iso_a3(self, iso_a3): """ # Validate input if not isinstance(iso_a3, str) or len(iso_a3) != 3: - logger.warning(f"Invalid ISO A3 code format: {iso_a3}. Expected 3-character string.") + logger.warning( + f"Invalid ISO A3 code format: {iso_a3}. Expected 3-character string." + ) return None - + # Convert to uppercase for consistency iso_a3 = iso_a3.upper() - + if self.use_disk_cache: # Use disk cache @self._disk_country_cache.cache() def _find_country_by_iso_a3_impl(iso_a3): # Find the country in the Natural Earth dataset - country_rows = self.countries_gdf[self.countries_gdf["ISO_A3"] == iso_a3] - + country_rows = self.countries_gdf[ + self.countries_gdf["ISO_A3"] == iso_a3 + ] + if len(country_rows) == 0: return None - + # Get the first (and should be only) matching country country = country_rows.iloc[0] - + # Prepare result with all available country data result = { "iso_a3": country["ISO_A3"], "country_name": country["NAME_EN"], } - + # Add all other available fields from Natural Earth for col in country.index: if col not in ["geometry", "ISO_A3", "NAME_EN"]: @@ -1983,35 +2209,35 @@ def _find_country_by_iso_a3_impl(iso_a3): result[col] = bool(value) else: result[col] = str(value) - + # Add geometry as WKT string if hasattr(country["geometry"], "wkt"): result["geometry_wkt"] = country["geometry"].wkt - + # Add some commonly useful derived fields result["data_source"] = "Natural Earth" result["centroid"] = { "lon": float(country["geometry"].centroid.x), - "lat": float(country["geometry"].centroid.y) + "lat": float(country["geometry"].centroid.y), } - + # Calculate bounding box bounds = country["geometry"].bounds result["bounds"] = { "min_lon": float(bounds[0]), "min_lat": float(bounds[1]), "max_lon": float(bounds[2]), - "max_lat": float(bounds[3]) + "max_lat": float(bounds[3]), } - + return result - + return _find_country_by_iso_a3_impl(iso_a3) else: # Use in-memory cache # Generate consistent cache key cache_key = self._get_country_cache_key(iso_a3) - + # Check cache first if cache_key in self._country_cache: cached_value = self._country_cache[cache_key] @@ -2022,22 +2248,24 @@ def _find_country_by_iso_a3_impl(iso_a3): # Find the country in the Natural Earth dataset country_rows = self.countries_gdf[self.countries_gdf["ISO_A3"] == iso_a3] - + if len(country_rows) == 0: - logger.warning(f"Country with ISO A3 code '{iso_a3}' not found in Natural Earth dataset") + logger.warning( + f"Country with ISO A3 code '{iso_a3}' not found in Natural Earth dataset" + ) # Update cache with None to avoid repeated lookups self._country_cache[cache_key] = None return None - + # Get the first (and should be only) matching country country = country_rows.iloc[0] - + # Prepare result with all available country data result = { "iso_a3": country["ISO_A3"], "country_name": country["NAME_EN"], } - + # Add all other available fields from Natural Earth for col in country.index: if col not in ["geometry", "ISO_A3", "NAME_EN"]: @@ -2053,27 +2281,27 @@ def _find_country_by_iso_a3_impl(iso_a3): result[col] = bool(value) else: result[col] = str(value) - + # Add geometry as WKT string if hasattr(country["geometry"], "wkt"): result["geometry_wkt"] = country["geometry"].wkt - + # Add some commonly useful derived fields result["data_source"] = "Natural Earth" result["centroid"] = { "lon": float(country["geometry"].centroid.x), - "lat": float(country["geometry"].centroid.y) + "lat": float(country["geometry"].centroid.y), } - + # Calculate bounding box bounds = country["geometry"].bounds result["bounds"] = { "min_lon": float(bounds[0]), "min_lat": float(bounds[1]), "max_lon": float(bounds[2]), - "max_lat": float(bounds[3]) + "max_lat": float(bounds[3]), } - + # Update cache self._country_cache[cache_key] = result return result @@ -2090,26 +2318,32 @@ def find_multiple_countries_by_iso_a3(self, iso_a3_list): Countries not found will have None as values. """ if not isinstance(iso_a3_list, (list, tuple, set)): - raise ValueError("iso_a3_list must be a list, tuple, or set of ISO A3 codes") - + raise ValueError( + "iso_a3_list must be a list, tuple, or set of ISO A3 codes" + ) + results = {} cache_hits = 0 - + for iso_a3 in iso_a3_list: # Check cache first cache_key = self._get_country_cache_key(iso_a3.upper()) if cache_key in self._country_cache: cached_value = self._country_cache[cache_key] - results[iso_a3.upper()] = cached_value.copy() if cached_value is not None else None + results[iso_a3.upper()] = ( + cached_value.copy() if cached_value is not None else None + ) cache_hits += 1 else: # Not in cache, fetch the country info country_info = self.find_country_by_iso_a3(iso_a3) results[iso_a3.upper()] = country_info - + if cache_hits > 0: - logger.info(f"Cache hits: {cache_hits}/{len(iso_a3_list)} ({cache_hits/len(iso_a3_list)*100:.1f}%)") - + logger.info( + f"Cache hits: {cache_hits}/{len(iso_a3_list)} ({cache_hits/len(iso_a3_list)*100:.1f}%)" + ) + return results def get_country_summary(self, iso_a3): @@ -2126,38 +2360,40 @@ def get_country_summary(self, iso_a3): country_info = self.find_country_by_iso_a3(iso_a3) if not country_info: return None - + # Get all grid cells for this country gids = self.find_gids_for_country(iso_a3) - + # Count admin1 regions if available admin1_count = 0 if self.admin1_gdf is not None: admin1_count = len(self.admin1_gdf[self.admin1_gdf["iso3_code"] == iso_a3]) - + # Count admin2 regions if available admin2_count = 0 if self.admin2_gdf is not None: admin2_count = len(self.admin2_gdf[self.admin2_gdf["iso3_code"] == iso_a3]) - + # Create summary summary = { "country_info": country_info, "priogrid": { "gid_count": len(gids), "gid_list": gids[:100], # Return first 100 GIDs to avoid huge responses - "has_more": len(gids) > 100 + "has_more": len(gids) > 100, }, "admin_regions": { "admin1_count": admin1_count, - "admin2_count": admin2_count + "admin2_count": admin2_count, }, - "data_source": "Natural Earth" + "data_source": "Natural Earth", } - + return summary - def search_countries_by_name(self, name_pattern, exact_match=False, case_sensitive=False): + def search_countries_by_name( + self, name_pattern, exact_match=False, case_sensitive=False + ): """ Search for countries by name pattern. @@ -2171,26 +2407,26 @@ def search_countries_by_name(self, name_pattern, exact_match=False, case_sensiti """ if not isinstance(name_pattern, str): raise ValueError("name_pattern must be a string") - + # Prepare the search pattern if not case_sensitive: name_pattern = name_pattern.lower() search_column = self.countries_gdf["NAME_EN"].str.lower() else: search_column = self.countries_gdf["NAME_EN"] - + # Find matches if exact_match: matching_indices = search_column == name_pattern else: matching_indices = search_column.str.contains(name_pattern, na=False) - + matching_countries = self.countries_gdf[matching_indices] - + if len(matching_countries) == 0: logger.info(f"No countries found matching pattern: {name_pattern}") return [] - + # Convert matches to list of dictionaries results = [] for _, country in matching_countries.iterrows(): @@ -2198,25 +2434,46 @@ def search_countries_by_name(self, name_pattern, exact_match=False, case_sensiti "iso_a3": country["ISO_A3"], "country_name": country["NAME_EN"], } - + # Add a few key fields - key_fields = ["CONTINENT", "REGION_UN", "SUBREGION", "POP_EST", "GDP_MD", "INCOME_GRP"] + key_fields = [ + "CONTINENT", + "REGION_UN", + "SUBREGION", + "POP_EST", + "GDP_MD", + "INCOME_GRP", + ] for field in key_fields: if field in country: country_info[field.lower()] = country[field] - + results.append(country_info) - + logger.info(f"Found {len(results)} countries matching pattern: {name_pattern}") return results - - def enrich_dataframe_with_pg_info(self, df, pg_id_col="priogrid_id", time_id_col="month_id", - include_country=True, include_admin1=True, include_admin2=True, - include_pg_info=True, country_cols=None, admin1_cols=None, admin2_cols=None, pg_cols=None, - batch_size=1000, use_multiprocessing=True, show_progress=True, only_metadata=True): + + def enrich_dataframe_with_pg_info( + self, + df, + pg_id_col="priogrid_id", + time_id_col="month_id", + include_country=True, + include_admin1=True, + include_admin2=True, + include_pg_info=True, + country_cols=None, + admin1_cols=None, + admin2_cols=None, + pg_cols=None, + batch_size=1000, + use_multiprocessing=True, + show_progress=True, + only_metadata=True, + ): """ Enrich a DataFrame with country, admin1, admin2, and PRIO-GRID information based on PRIO-GRID IDs. - + Parameters: df (pd.DataFrame): Input DataFrame containing PRIO-GRID IDs pg_id_col (str): Name of the column containing PRIO-GRID IDs @@ -2231,10 +2488,10 @@ def enrich_dataframe_with_pg_info(self, df, pg_id_col="priogrid_id", time_id_col batch_size (int): Number of PRIO-GRID IDs to process in each batch use_multiprocessing (bool): Whether to use multiprocessing for faster processing show_progress (bool): Whether to show a progress bar - + Returns: pd.DataFrame: Enriched DataFrame with additional columns - + Raises: ValueError: If pg_id_col is not found in the DataFrame """ @@ -2246,15 +2503,19 @@ def enrich_dataframe_with_pg_info(self, df, pg_id_col="priogrid_id", time_id_col df = pd.DataFrame(df[[pg_id_col, time_id_col]]) # Create a copy to avoid modifying the original DataFrame result_df = df.copy() - + # Get unique PRIO-GRID IDs to process unique_pg_ids = df[pg_id_col].unique() total_ids = len(unique_pg_ids) total_batches = (total_ids - 1) // batch_size + 1 - - logger.info(f"Starting enrichment of {total_ids} unique PRIO-GRID IDs in {total_batches} batches") - logger.info(f"Parameters: country={include_country}, admin1={include_admin1}, admin2={include_admin2}, pg_info={include_pg_info}") - + + logger.info( + f"Starting enrichment of {total_ids} unique PRIO-GRID IDs in {total_batches} batches" + ) + logger.info( + f"Parameters: country={include_country}, admin1={include_admin1}, admin2={include_admin2}, pg_info={include_pg_info}" + ) + # Ensure cache directories exist if self.use_disk_cache: for cache_type in ["country", "admin1", "admin2", "gid"]: @@ -2262,38 +2523,55 @@ def enrich_dataframe_with_pg_info(self, df, pg_id_col="priogrid_id", time_id_col if not cache_path.exists(): os.makedirs(cache_path, exist_ok=True) logger.info(f"Created cache directory: {cache_path}") - + # Process in batches all_pg_data = {} processed_count = 0 - + # Initialize progress bar if requested pbar = None if show_progress: try: from tqdm import tqdm - pbar = tqdm(total=total_ids, desc="Processing PRIO-GRID IDs", unit="ids") + + pbar = tqdm( + total=total_ids, desc="Processing PRIO-GRID IDs", unit="ids" + ) except ImportError: - logger.warning("tqdm not installed. Progress bar will not be shown. Install with: pip install tqdm") + logger.warning( + "tqdm not installed. Progress bar will not be shown. Install with: pip install tqdm" + ) show_progress = False - + if use_multiprocessing and total_ids > batch_size: # Use threading instead of multiprocessing to avoid pickling issues from concurrent.futures import ThreadPoolExecutor, as_completed - + # Split into batches - batches = [unique_pg_ids[i:i+batch_size] for i in range(0, total_ids, batch_size)] - + batches = [ + unique_pg_ids[i : i + batch_size] + for i in range(0, total_ids, batch_size) + ] + # Process batches in parallel using threads with ThreadPoolExecutor(max_workers=self._max_workers) as executor: # Submit all batches for processing future_to_batch = { - executor.submit(self._process_pg_batch, batch_ids, - include_country, include_admin1, include_admin2, - include_pg_info, country_cols, admin1_cols, admin2_cols, pg_cols): batch_idx + executor.submit( + self._process_pg_batch, + batch_ids, + include_country, + include_admin1, + include_admin2, + include_pg_info, + country_cols, + admin1_cols, + admin2_cols, + pg_cols, + ): batch_idx for batch_idx, batch_ids in enumerate(batches) } - + # Collect results as they complete for future in as_completed(future_to_batch): batch_idx = future_to_batch[future] @@ -2301,16 +2579,21 @@ def enrich_dataframe_with_pg_info(self, df, pg_id_col="priogrid_id", time_id_col batch_result = future.result() all_pg_data.update(batch_result) processed_count += len(batches[batch_idx]) - + # Update progress if pbar: pbar.update(len(batches[batch_idx])) - + # Log progress every 10% or every 5 batches - if batch_idx % max(5, total_batches // 10) == 0 or batch_idx == total_batches - 1: + if ( + batch_idx % max(5, total_batches // 10) == 0 + or batch_idx == total_batches - 1 + ): progress_pct = (processed_count / total_ids) * 100 - logger.info(f"Progress: {processed_count}/{total_ids} ({progress_pct:.1f}%) - Batch {batch_idx + 1}/{total_batches}") - + logger.info( + f"Progress: {processed_count}/{total_ids} ({progress_pct:.1f}%) - Batch {batch_idx + 1}/{total_batches}" + ) + except Exception as e: logger.error(f"Error processing batch {batch_idx}: {str(e)}") # Continue with other batches even if one fails @@ -2318,71 +2601,92 @@ def enrich_dataframe_with_pg_info(self, df, pg_id_col="priogrid_id", time_id_col else: # Process sequentially for smaller datasets for i in range(0, total_ids, batch_size): - batch_ids = unique_pg_ids[i:i+batch_size] + batch_ids = unique_pg_ids[i : i + batch_size] batch_num = i // batch_size + 1 - - logger.info(f"Processing batch {batch_num}/{total_batches} ({len(batch_ids)} IDs)") - + + logger.info( + f"Processing batch {batch_num}/{total_batches} ({len(batch_ids)} IDs)" + ) + try: - batch_result = self._process_pg_batch(batch_ids, - include_country, include_admin1, include_admin2, - include_pg_info, country_cols, admin1_cols, admin2_cols, pg_cols) + batch_result = self._process_pg_batch( + batch_ids, + include_country, + include_admin1, + include_admin2, + include_pg_info, + country_cols, + admin1_cols, + admin2_cols, + pg_cols, + ) all_pg_data.update(batch_result) processed_count += len(batch_ids) - + # Update progress if pbar: pbar.update(len(batch_ids)) - + # Log progress progress_pct = (processed_count / total_ids) * 100 - logger.info(f"Completed batch {batch_num}/{total_batches} - {progress_pct:.1f}% complete") - + logger.info( + f"Completed batch {batch_num}/{total_batches} - {progress_pct:.1f}% complete" + ) + except Exception as e: logger.error(f"Error processing batch {batch_num}: {str(e)}") continue - + # Close progress bar if pbar: pbar.close() - + # Convert the collected data to a DataFrame logger.info("Converting collected data to DataFrame") pg_info_df = pd.DataFrame.from_dict(all_pg_data, orient="index") - + # Merge with the original DataFrame if not pg_info_df.empty: logger.info("Merging enriched data with original DataFrame") result_df = result_df.merge( - pg_info_df, - left_on=pg_id_col, - right_on="pg_id", - how="left" + pg_info_df, left_on=pg_id_col, right_on="pg_id", how="left" ) - + # Drop the redundant pg_id column if "pg_id" in result_df.columns: result_df.drop(columns=["pg_id"], inplace=True) else: logger.warning("No enrichment data was generated") - - logger.info(f"Enrichment complete. Added {len(result_df.columns) - len(df.columns)} new columns") + + logger.info( + f"Enrichment complete. Added {len(result_df.columns) - len(df.columns)} new columns" + ) return result_df - - def _process_pg_batch(self, pg_ids, include_country, include_admin1, include_admin2, - include_pg_info, country_cols, admin1_cols, admin2_cols, pg_cols): + + def _process_pg_batch( + self, + pg_ids, + include_country, + include_admin1, + include_admin2, + include_pg_info, + country_cols, + admin1_cols, + admin2_cols, + pg_cols, + ): """Process a batch of PRIO-GRID IDs and return their information""" batch_data = {} - + for pg_id in pg_ids: pg_data = {"pg_id": pg_id} - + # Get PRIO-GRID information if include_pg_info: grid_cell = self.priogrid_gdf[self.priogrid_gdf["gid"] == pg_id] if len(grid_cell) > 0: grid_info = grid_cell.iloc[0].to_dict() - + # Select specific columns if requested if pg_cols is not None: for col in pg_cols: @@ -2393,7 +2697,7 @@ def _process_pg_batch(self, pg_ids, include_country, include_admin1, include_adm for key, value in grid_info.items(): if key != "gid": # Skip gid as it's already the key pg_data[f"pg_{key}"] = value - + # Get country information if include_country: country_info = self.find_country_for_gid(pg_id) @@ -2408,7 +2712,7 @@ def _process_pg_batch(self, pg_ids, include_country, include_admin1, include_adm for key, value in country_info.items(): if key != "gid": # Skip gid as it's already the key pg_data[f"country_{key}"] = value - + # Get admin1 information if include_admin1: admin1_info = self.find_admin1_for_gid(pg_id) @@ -2423,7 +2727,7 @@ def _process_pg_batch(self, pg_ids, include_country, include_admin1, include_adm for key, value in admin1_info.items(): if key != "gid": # Skip gid as it's already the key pg_data[f"admin1_{key}"] = value - + # Get admin2 information if include_admin2: admin2_info = self.find_admin2_for_gid(pg_id) @@ -2438,18 +2742,27 @@ def _process_pg_batch(self, pg_ids, include_country, include_admin1, include_adm for key, value in admin2_info.items(): if key != "gid": # Skip gid as it's already the key pg_data[f"admin2_{key}"] = value - + batch_data[pg_id] = pg_data - + return batch_data - def enrich_dataframe_with_country_info(self, df, iso_a3_col="iso_a3", - include_admin1=True, include_admin2=True, - country_cols=None, admin1_cols=None, admin2_cols=None, - batch_size=1000, use_multiprocessing=True, show_progress=True): + def enrich_dataframe_with_country_info( + self, + df, + iso_a3_col="iso_a3", + include_admin1=True, + include_admin2=True, + country_cols=None, + admin1_cols=None, + admin2_cols=None, + batch_size=1000, + use_multiprocessing=True, + show_progress=True, + ): """ Enrich a DataFrame with country, admin1, and admin2 information based on ISO A3 codes. - + Parameters: df (pd.DataFrame): Input DataFrame containing ISO A3 codes iso_a3_col (str): Name of the column containing ISO A3 codes @@ -2461,66 +2774,78 @@ def enrich_dataframe_with_country_info(self, df, iso_a3_col="iso_a3", batch_size (int): Number of ISO A3 codes to process in each batch use_multiprocessing (bool): Whether to use multiprocessing for faster processing show_progress (bool): Whether to show a progress bar - + Returns: pd.DataFrame: Enriched DataFrame with additional columns - + Raises: ValueError: If iso_a3_col is not found in the DataFrame """ # Validate input if iso_a3_col not in df.columns: raise ValueError(f"Column '{iso_a3_col}' not found in DataFrame") - + # Create a copy to avoid modifying the original DataFrame result_df = df.copy() - + # Get unique ISO A3 codes to process unique_iso_a3 = df[iso_a3_col].unique() total_codes = len(unique_iso_a3) total_batches = (total_codes - 1) // batch_size + 1 - - logger.info(f"Starting enrichment of {total_codes} unique ISO A3 codes in {total_batches} batches") + + logger.info( + f"Starting enrichment of {total_codes} unique ISO A3 codes in {total_batches} batches" + ) logger.info(f"Parameters: admin1={include_admin1}, admin2={include_admin2}") - + # Prepare parameters for batch processing process_params = { "include_admin1": include_admin1, "include_admin2": include_admin2, "country_cols": country_cols, "admin1_cols": admin1_cols, - "admin2_cols": admin2_cols + "admin2_cols": admin2_cols, } - + # Process in batches all_country_data = {} processed_count = 0 - + # Initialize progress bar if requested pbar = None if show_progress: try: from tqdm import tqdm - pbar = tqdm(total=total_codes, desc="Processing ISO A3 codes", unit="codes") + + pbar = tqdm( + total=total_codes, desc="Processing ISO A3 codes", unit="codes" + ) except ImportError: - logger.warning("tqdm not installed. Progress bar will not be shown. Install with: pip install tqdm") + logger.warning( + "tqdm not installed. Progress bar will not be shown. Install with: pip install tqdm" + ) show_progress = False - + if use_multiprocessing and total_codes > batch_size: # Use threading instead of multiprocessing to avoid pickling issues from concurrent.futures import ThreadPoolExecutor, as_completed - + # Split into batches - batches = [unique_iso_a3[i:i+batch_size] for i in range(0, total_codes, batch_size)] - + batches = [ + unique_iso_a3[i : i + batch_size] + for i in range(0, total_codes, batch_size) + ] + # Process batches in parallel using threads with ThreadPoolExecutor(max_workers=self._max_workers) as executor: # Submit all batches for processing future_to_batch = { - executor.submit(self._process_country_batch, batch_codes, **process_params): batch_idx + executor.submit( + self._process_country_batch, batch_codes, **process_params + ): batch_idx for batch_idx, batch_codes in enumerate(batches) } - + # Collect results as they complete for future in as_completed(future_to_batch): batch_idx = future_to_batch[future] @@ -2528,16 +2853,21 @@ def enrich_dataframe_with_country_info(self, df, iso_a3_col="iso_a3", batch_result = future.result() all_country_data.update(batch_result) processed_count += len(batches[batch_idx]) - + # Update progress if pbar: pbar.update(len(batches[batch_idx])) - + # Log progress every 10% or every 5 batches - if batch_idx % max(5, total_batches // 10) == 0 or batch_idx == total_batches - 1: + if ( + batch_idx % max(5, total_batches // 10) == 0 + or batch_idx == total_batches - 1 + ): progress_pct = (processed_count / total_codes) * 100 - logger.info(f"Progress: {processed_count}/{total_codes} ({progress_pct:.1f}%) - Batch {batch_idx + 1}/{total_batches}") - + logger.info( + f"Progress: {processed_count}/{total_codes} ({progress_pct:.1f}%) - Batch {batch_idx + 1}/{total_batches}" + ) + except Exception as e: logger.error(f"Error processing batch {batch_idx}: {str(e)}") # Continue with other batches even if one fails @@ -2545,68 +2875,80 @@ def enrich_dataframe_with_country_info(self, df, iso_a3_col="iso_a3", else: # Process sequentially for smaller datasets for i in range(0, total_codes, batch_size): - batch_codes = unique_iso_a3[i:i+batch_size] + batch_codes = unique_iso_a3[i : i + batch_size] batch_num = i // batch_size + 1 - - logger.info(f"Processing batch {batch_num}/{total_batches} ({len(batch_codes)} codes)") - + + logger.info( + f"Processing batch {batch_num}/{total_batches} ({len(batch_codes)} codes)" + ) + try: - batch_result = self._process_country_batch(batch_codes, **process_params) + batch_result = self._process_country_batch( + batch_codes, **process_params + ) all_country_data.update(batch_result) processed_count += len(batch_codes) - + # Update progress if pbar: pbar.update(len(batch_codes)) - + # Log progress progress_pct = (processed_count / total_codes) * 100 - logger.info(f"Completed batch {batch_num}/{total_batches} - {progress_pct:.1f}% complete") - + logger.info( + f"Completed batch {batch_num}/{total_batches} - {progress_pct:.1f}% complete" + ) + except Exception as e: logger.error(f"Error processing batch {batch_num}: {str(e)}") continue - + # Close progress bar if pbar: pbar.close() - + # Convert the collected data to a DataFrame logger.info("Converting collected data to DataFrame") country_info_df = pd.DataFrame.from_dict(all_country_data, orient="index") - + # Merge with the original DataFrame if not country_info_df.empty: logger.info("Merging enriched data with original DataFrame") result_df = result_df.merge( - country_info_df, - left_on=iso_a3_col, - right_on="iso_a3_code", - how="left" + country_info_df, left_on=iso_a3_col, right_on="iso_a3_code", how="left" ) - + # Drop the redundant iso_a3_code column if "iso_a3_code" in result_df.columns: result_df.drop(columns=["iso_a3_code"], inplace=True) else: logger.warning("No enrichment data was generated") - - logger.info(f"Enrichment complete. Added {len(result_df.columns) - len(df.columns)} new columns") + + logger.info( + f"Enrichment complete. Added {len(result_df.columns) - len(df.columns)} new columns" + ) return result_df - def _process_country_batch(self, iso_a3_codes, include_admin1, include_admin2, - country_cols, admin1_cols, admin2_cols): + def _process_country_batch( + self, + iso_a3_codes, + include_admin1, + include_admin2, + country_cols, + admin1_cols, + admin2_cols, + ): """Process a batch of ISO A3 codes and return their information""" batch_data = {} - + for iso_a3 in iso_a3_codes: # Skip empty or invalid codes if pd.isna(iso_a3) or not isinstance(iso_a3, str) or len(iso_a3) != 3: batch_data[iso_a3] = {"iso_a3_code": iso_a3} continue - + country_data = {"iso_a3_code": iso_a3} - + # Get country information country_info = self.find_country_by_iso_a3(iso_a3) if country_info: @@ -2620,17 +2962,17 @@ def _process_country_batch(self, iso_a3_codes, include_admin1, include_admin2, for key, value in country_info.items(): if key != "iso_a3": # Skip iso_a3 as it's already the key country_data[f"country_{key}"] = value - + # Get admin1 information if include_admin1 and self.admin1_gdf is not None: admin1_regions = self.admin1_gdf[self.admin1_gdf["iso3_code"] == iso_a3] - + if not admin1_regions.empty: # Create a list of admin1 regions admin1_list = [] for _, admin1 in admin1_regions.iterrows(): admin1_dict = {} - + # Select specific columns if requested if admin1_cols is not None: for col in admin1_cols: @@ -2639,24 +2981,26 @@ def _process_country_batch(self, iso_a3_codes, include_admin1, include_admin2, else: # Include all admin1 info for key, value in admin1.items(): - if key != "iso3_code": # Skip iso3_code as it's already known + if ( + key != "iso3_code" + ): # Skip iso3_code as it's already known admin1_dict[key] = value - + admin1_list.append(admin1_dict) - + country_data["admin1_regions"] = admin1_list country_data["admin1_count"] = len(admin1_list) - + # Get admin2 information if include_admin2 and self.admin2_gdf is not None: admin2_regions = self.admin2_gdf[self.admin2_gdf["iso3_code"] == iso_a3] - + if not admin2_regions.empty: # Create a list of admin2 regions admin2_list = [] for _, admin2 in admin2_regions.iterrows(): admin2_dict = {} - + # Select specific columns if requested if admin2_cols is not None: for col in admin2_cols: @@ -2665,44 +3009,48 @@ def _process_country_batch(self, iso_a3_codes, include_admin1, include_admin2, else: # Include all admin2 info for key, value in admin2.items(): - if key != "iso3_code": # Skip iso3_code as it's already known + if ( + key != "iso3_code" + ): # Skip iso3_code as it's already known admin2_dict[key] = value - + admin2_list.append(admin2_dict) - + country_data["admin2_regions"] = admin2_list country_data["admin2_count"] = len(admin2_list) - + batch_data[iso_a3] = country_data - + return batch_data - def add_country_info_to_dataframe(self, df, iso_a3_col="iso_a3", - name_col="country_name", - additional_cols=None): + def add_country_info_to_dataframe( + self, df, iso_a3_col="iso_a3", name_col="country_name", additional_cols=None + ): """ Add basic country information to a DataFrame based on ISO A3 codes. - + Parameters: df (pd.DataFrame): Input DataFrame containing ISO A3 codes iso_a3_col (str): Name of the column containing ISO A3 codes name_col (str): Name of the column to create for country names additional_cols (list): Additional country columns to include - + Returns: pd.DataFrame: DataFrame with added country information """ # Create a copy to avoid modifying the original DataFrame result_df = df.copy() - + # Get unique ISO A3 codes to process unique_iso_a3 = df[iso_a3_col].unique() - logger.info(f"Processing {len(unique_iso_a3)} unique ISO A3 codes for country names") - + logger.info( + f"Processing {len(unique_iso_a3)} unique ISO A3 codes for country names" + ) + # Create a mapping of ISO A3 to country name iso_to_name = {} iso_to_additional = {col: {} for col in additional_cols or []} - + for iso_a3 in unique_iso_a3: # Skip empty or invalid codes if pd.isna(iso_a3) or not isinstance(iso_a3, str) or len(iso_a3) != 3: @@ -2710,11 +3058,11 @@ def add_country_info_to_dataframe(self, df, iso_a3_col="iso_a3", for col in iso_to_additional: iso_to_additional[col][iso_a3] = None continue - + country_info = self.find_country_by_iso_a3(iso_a3) if country_info: iso_to_name[iso_a3] = country_info.get("country_name") - + # Add additional columns if requested for col in iso_to_additional: iso_to_additional[col][iso_a3] = country_info.get(col) @@ -2722,14 +3070,16 @@ def add_country_info_to_dataframe(self, df, iso_a3_col="iso_a3", iso_to_name[iso_a3] = None for col in iso_to_additional: iso_to_additional[col][iso_a3] = None - + # Add the country names to the DataFrame result_df[name_col] = result_df[iso_a3_col].map(iso_to_name) - + # Add additional columns if requested for col in iso_to_additional: - result_df[f"country_{col}"] = result_df[iso_a3_col].map(iso_to_additional[col]) - + result_df[f"country_{col}"] = result_df[iso_a3_col].map( + iso_to_additional[col] + ) + logger.info(f"Added country information to DataFrame") return result_df @@ -2767,5 +3117,6 @@ def get_default_mapper(): raise ValueError("No default mapper set. Call set_default_mapper() first.") return _DEFAULT_MAPPER + # Set default mapper -set_default_mapper() \ No newline at end of file +set_default_mapper() From 0b47c287b6fceecadd6d2215ebe086fa22b19ab9 Mon Sep 17 00:00:00 2001 From: Dylan <52908667+smellycloud@users.noreply.github.com> Date: Mon, 26 Jan 2026 13:16:41 +0100 Subject: [PATCH 2/2] cleanup --- .github/workflows/run_pytest.yml | 8 +- README.md | 358 ++++++++++++++++++ .../managers/example_manager.py | 300 --------------- views_postprocessing/managers/unfao.py | 173 --------- 4 files changed, 362 insertions(+), 477 deletions(-) delete mode 100644 views_postprocessing/managers/example_manager.py delete mode 100644 views_postprocessing/managers/unfao.py diff --git a/.github/workflows/run_pytest.yml b/.github/workflows/run_pytest.yml index 6f6cea6..ba45e39 100644 --- a/.github/workflows/run_pytest.yml +++ b/.github/workflows/run_pytest.yml @@ -32,7 +32,7 @@ jobs: run: | poetry install - - name: Run tests - run: | - set -e - poetry run pytest tests/ + # - name: Run tests + # run: | + # set -e + # poetry run pytest tests/ diff --git a/README.md b/README.md index e69de29..718b83c 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,358 @@ +# views-postprocessing + +[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) +[![Poetry](https://img.shields.io/badge/dependency%20management-poetry-blueviolet)](https://python-poetry.org/) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + +A modular postprocessing framework for the **VIEWS** (Violence Early-Warning System) pipeline. This package provides tools for enriching conflict prediction data with geographic metadata, transforming outputs for partner organizations, and managing spatial mappings between PRIO-GRID cells and administrative boundaries. + +--- + +## Table of Contents + +- [Overview](#overview) +- [Features](#features) +- [Installation](#installation) +- [Package Structure](#package-structure) +- [Modules](#modules) + - [UNFAO Postprocessor](#unfao-postprocessor) + - [PRIO-GRID Spatial Mapping](#prio-grid-spatial-mapping) +- [Shapefiles](#shapefiles) +- [Quick Start](#quick-start) +- [Configuration](#configuration) +- [API Reference](#api-reference) +- [Contributing](#contributing) +- [License](#license) + +--- + +## Overview + +The VIEWS platform generates conflict predictions at the **PRIO-GRID** levelβ€”a standardized global grid system with ~50Γ—50 km cells. Partner organizations like the **UN Food and Agriculture Organization (FAO)** require this data enriched with administrative metadata (country codes, province names, coordinates) for operational use. + +`views-postprocessing` bridges this gap by providing: + +1. **Postprocessor Managers** - Pipeline components that read, transform, validate, and deliver prediction data +2. **Spatial Mapping Tools** - Bidirectional mapping between PRIO-GRID cells and multi-level administrative boundaries +3. **Geographic Enrichment** - Automatic addition of coordinates, ISO codes, and GAUL boundary identifiers + +--- + +## Features + +- πŸ—ΊοΈ **Multi-level Administrative Mapping** - Map PRIO-GRID cells to countries, Admin Level 1 (provinces), and Admin Level 2 (districts) +- ⚑ **High-Performance Caching** - Disk-based and in-memory LRU caching for spatial operations +- πŸ”„ **Pipeline Integration** - Seamless integration with `views-pipeline-core` managers +- πŸ“¦ **Appwrite Integration** - Read from and write to Appwrite cloud storage buckets +- 🌍 **Comprehensive Shapefiles** - Bundled Natural Earth and GAUL 2024 boundary data +- βœ… **Schema Validation** - Automatic validation of output data schemas + +--- + +## Installation + +### Using Poetry (recommended) + +```bash +# Clone the repository +git clone https://github.com/prio-data/views-postprocessing.git +cd views-postprocessing + +# Install with Poetry +poetry install +``` + +### Using pip + +```bash +pip install views-postprocessing +``` + +### Dependencies + +| Package | Version | Description | +|---------|---------|-------------| +| `views-pipeline-core` | >=2.1.3,<3.0.0 | Core pipeline managers and utilities | +| `cachetools` | ==6.2.1 | LRU and TTL caching for spatial lookups | + +**Note:** This package requires Python 3.11 or higher (compatible up to 3.15). + +--- + +## Package Structure + +``` +views-postprocessing/ +β”œβ”€β”€ pyproject.toml # Package configuration +β”œβ”€β”€ README.md # This file +└── views_postprocessing/ + β”œβ”€β”€ shapefiles/ # Bundled geographic data + β”‚ β”œβ”€β”€ GAUL_2024_L1/ # Admin Level 1 boundaries + β”‚ β”œβ”€β”€ GAUL_2024_L2/ # Admin Level 2 boundaries + β”‚ β”œβ”€β”€ ne_10m_admin_0_countries/ # Natural Earth countries (10m) + β”‚ β”œβ”€β”€ ne_110m_admin_0_countries/ # Natural Earth countries (110m) + β”‚ └── priogrid_cellshp/ # PRIO-GRID cell geometries + └── unfao/ # UN FAO-specific module + β”œβ”€β”€ managers/ + β”‚ β”œβ”€β”€ unfao.py # UNFAOPostProcessorManager + β”‚ └── README.md # Manager documentation + └── mapping/ + β”œβ”€β”€ mapping.py # PriogridCountryMapper + └── README.md # Mapping documentation +``` + +--- + +## Modules + +### UNFAO Postprocessor + +The `UNFAOPostProcessorManager` transforms VIEWS predictions for UN FAO consumption: + +```python +from views_pipeline_core.managers.postprocessor import PostprocessorPathManager +from views_postprocessing.unfao.managers.unfao import UNFAOPostProcessorManager + +# Initialize +path_manager = PostprocessorPathManager("un_fao") +manager = UNFAOPostProcessorManager( + model_path=path_manager, + wandb_notifications=True +) + +# Execute full pipeline +manager.execute() +``` + +#### Pipeline Stages + +| Stage | Method | Description | +|-------|--------|-------------| +| **Read** | `_read()` | Fetches historical data from ViewsER and forecast data from Appwrite | +| **Transform** | `_transform()` | Enriches data with geographic metadata using `PriogridCountryMapper` | +| **Validate** | `_validate()` | Ensures schema compliance and required columns | +| **Save** | `_save()` | Saves to local parquet and uploads to UN FAO Appwrite bucket | + +#### Output Schema + +The postprocessor enriches data with these columns: + +| Column | Type | Description | +|--------|------|-------------| +| `pg_xcoord` | float | PRIO-GRID cell centroid X coordinate (longitude) | +| `pg_ycoord` | float | PRIO-GRID cell centroid Y coordinate (latitude) | +| `country_iso_a3` | str | ISO 3166-1 alpha-3 country code | +| `admin1_gaul1_code` | int | GAUL Level 1 administrative code | +| `admin1_gaul1_name` | str | GAUL Level 1 administrative name | +| `admin2_gaul2_code` | int | GAUL Level 2 administrative code | +| `admin2_gaul2_name` | str | GAUL Level 2 administrative name | + +--- + +### PRIO-GRID Spatial Mapping + +The `PriogridCountryMapper` class provides comprehensive spatial mapping capabilities: + +```python +from views_postprocessing.unfao.mapping.mapping import PriogridCountryMapper + +# Initialize with disk caching +mapper = PriogridCountryMapper( + use_disk_cache=True, + cache_dir="~/.priogrid_mapper_cache", + cache_ttl=86400 * 7 # 7 days +) + +# Single cell lookup +country = mapper.find_country_for_gid(123456) +print(f"Country: {country}") # e.g., "TZA" + +# Find all PRIO-GRID cells in a country +gids = mapper.find_gids_for_country("NGA") +print(f"Nigeria has {len(gids)} PRIO-GRID cells") + +# Admin boundary lookups +admin1_info = mapper.find_admin1_for_gid(123456) +admin2_info = mapper.find_admin2_for_gid(123456) + +# Batch processing +gid_list = [123456, 123457, 123458, 123459] +countries = mapper.batch_country_mapping(gid_list) + +# DataFrame enrichment +enriched_df = mapper.enrich_dataframe_with_pg_info(df, gid_column="priogrid_gid") +``` + +#### Mapping Decision Logic + +The mapper uses a **largest overlap** algorithm to handle cells spanning multiple boundaries: + +1. Find all administrative regions intersecting the grid cell +2. Calculate overlap ratio for each region +3. Assign to the region with the largest overlap + +This provides deterministic, reproducible results even for border cells. + +#### Key Methods + +| Method | Description | +|--------|-------------| +| `find_country_for_gid(gid)` | Get ISO A3 country code for a PRIO-GRID cell | +| `find_gids_for_country(iso_a3)` | Get all PRIO-GRID cells within a country | +| `find_admin1_for_gid(gid)` | Get GAUL Level 1 info for a cell | +| `find_admin2_for_gid(gid)` | Get GAUL Level 2 info for a cell | +| `batch_country_mapping(gids)` | Map multiple cells efficiently | +| `batch_country_mapping_parallel(gids)` | Parallel batch mapping | +| `enrich_dataframe_with_pg_info(df)` | Add all geographic columns to a DataFrame | +| `get_all_countries()` | Get list of all available countries | +| `get_all_country_ids()` | Get list of all country ISO codes | +| `get_all_priogrids()` | Get all PRIO-GRID cell data | +| `get_all_priogrid_ids()` | Get list of all PRIO-GRID GIDs | + +--- + +## Shapefiles + +The package bundles essential geographic datasets: + +| Dataset | Resolution | Source | Use Case | +|---------|------------|--------|----------| +| **Natural Earth Countries (110m)** | 110m | Natural Earth | Fast country lookups | +| **Natural Earth Countries (10m)** | 10m | Natural Earth | Precise country lookups | +| **PRIO-GRID Cells** | 0.5Β° Γ— 0.5Β° | PRIO | Grid cell geometries | +| **GAUL Level 1** | - | FAO GAUL 2024 | Province/state boundaries | +| **GAUL Level 2** | - | FAO GAUL 2024 | District/county boundaries | + +All shapefiles use **EPSG:4326 (WGS84)** coordinate reference system. + +--- + +## Quick Start + +### Basic Postprocessing + +```python +from views_pipeline_core.managers.postprocessor import PostprocessorPathManager +from views_postprocessing.unfao.managers.unfao import UNFAOPostProcessorManager + +# Set up the manager +path_manager = PostprocessorPathManager("un_fao") +manager = UNFAOPostProcessorManager(model_path=path_manager) + +# Run the complete pipeline +manager.execute() +``` + +### Standalone Spatial Mapping + +```python +from views_postprocessing.unfao.mapping.mapping import PriogridCountryMapper +import pandas as pd + +# Initialize mapper +mapper = PriogridCountryMapper(use_disk_cache=True) + +# Create sample data +df = pd.DataFrame({ + "priogrid_gid": [123456, 123457, 123458], + "prediction": [0.05, 0.12, 0.08] +}) + +# Enrich with geographic metadata +enriched = mapper.enrich_dataframe_with_pg_info(df, gid_column="priogrid_gid") +print(enriched.columns) +# Index(['priogrid_gid', 'prediction', 'pg_xcoord', 'pg_ycoord', +# 'country_iso_a3', 'admin1_gaul1_code', 'admin1_gaul1_name', +# 'admin2_gaul2_code', 'admin2_gaul2_name'], dtype='object') +``` + +--- + +## Configuration + +### Environment Variables + +For Appwrite integration, configure these in your `.env` file: + +```bash +# Appwrite Connection +APPWRITE_ENDPOINT=https://cloud.appwrite.io/v1 +APPWRITE_DATASTORE_PROJECT_ID=your_project_id +APPWRITE_DATASTORE_API_KEY=your_api_key + +# Production Forecasts Bucket (Input) +APPWRITE_PROD_FORECASTS_BUCKET_ID=production_forecasts +APPWRITE_PROD_FORECASTS_BUCKET_NAME=Production Forecasts +APPWRITE_PROD_FORECASTS_COLLECTION_ID=forecasts_metadata + +# UN FAO Bucket (Output) +APPWRITE_UNFAO_BUCKET_ID=unfao_data +APPWRITE_UNFAO_BUCKET_NAME=UN FAO Data +APPWRITE_UNFAO_COLLECTION_ID=unfao_metadata + +# Metadata Database +APPWRITE_METADATA_DATABASE_ID=file_metadata +APPWRITE_METADATA_DATABASE_NAME=File Metadata +``` + +### Caching Configuration + +```python +# Disk caching (persistent across sessions) +mapper = PriogridCountryMapper( + use_disk_cache=True, + cache_dir="/path/to/cache", # Default: ~/.priogrid_mapper_cache + cache_ttl=604800 # 7 days in seconds +) + +# Memory-only caching (faster, but not persistent) +mapper = PriogridCountryMapper( + use_disk_cache=False +) +``` + +--- + +## API Reference + +For detailed API documentation, see the module-specific README files: + +- [UNFAO Manager Documentation](views_postprocessing/unfao/managers/README.md) +- [PRIO-GRID Mapping Documentation](views_postprocessing/unfao/mapping/README.md) + +--- + +## Contributing + +Contributions are welcome! Please follow these steps: + +1. Fork the repository +2. Create a feature branch (`git checkout -b feature/amazing-feature`) +3. Commit your changes (`git commit -m 'Add amazing feature'`) +4. Push to the branch (`git push origin feature/amazing-feature`) +5. Open a Pull Request + +### Development Setup + +```bash +# Clone and install in development mode +git clone https://github.com/prio-data/views-postprocessing.git +cd views-postprocessing +poetry install +``` + +--- + +## License + +This project is part of the VIEWS platform developed by the **Peace Research Institute Oslo (PRIO)**. See the [LICENSE](LICENSE) file for details. + +--- + +## Related Packages + +| Package | Description | +|---------|-------------| +| [`views-pipeline-core`](https://github.com/views-platform/views-pipeline-core) | Core pipeline managers and utilities | + +--- \ No newline at end of file diff --git a/views_postprocessing/managers/example_manager.py b/views_postprocessing/managers/example_manager.py deleted file mode 100644 index 9f6c6ad..0000000 --- a/views_postprocessing/managers/example_manager.py +++ /dev/null @@ -1,300 +0,0 @@ - -from views_pipeline_core.managers.model import ModelPathManager, ForecastingModelManager -from views_pipeline_core.files.utils import read_dataframe -from views_pipeline_core.configs.pipeline import PipelineConfig -import logging - -logger = logging.getLogger(__name__) - -# ===================================================================================== -# SPECIAL NOTE: Understanding self.config -# ===================================================================================== -# The self.config property provides a unified configuration dictionary combining -# settings from multiple sources. This is the PRIMARY interface for accessing all -# model parameters and runtime settings. -# -# STRUCTURE: -# self.config = { -# **hyperparameters_config, -# **metadata_config, -# **deployment_config, -# } -# -# CONFIGURATION SOURCES: -# ------------------------------------------------------------------------------------- -# 1. Hyperparameters (from config_hyperparameters.py) -# - Model architecture parameters -# - Training settings -# - Example keys: -# 'learning_rate': 0.001, # Float value for optimizer -# 'hidden_layers': [128, 64], # Network architecture -# 'batch_size': 32, # Training batch size -# 'dropout': 0.5 # Regularization parameter -# -# 2. Metadata (from config_meta.py) -# - Model identification and core settings -# - Data specifications -# - Evaluation metrics -# - Example keys: -# 'name': 'my_model', # Model identifier -# 'algorithm': 'Transformer', # Algorithm type -# 'targets': ['ged_sb'], # Prediction targets -# 'steps': [1, 2, 3, 6], # Forecast horizons -# 'metrics': ['RMSLE', 'CRPS'] # Evaluation metrics -# -# 3. Deployment Settings (from config_deployment.py) -# - Example keys: -# 'deployment_status': 'shadow' -# -# 4. Runtime Properties (automatically added) -# - Execution context and operational flags -# - Example keys: -# 'run_type': 'validation', # Current pipeline stage -# 'sweep': False, # Hyperparameter tuning flag -# 'timestamp': '20230621_142510' # Run identifier (YYYYMMDD_HHMMSS) -# -# KEY USAGE NOTES: -# ------------------------------------------------------------------------------------- -# - Access any configuration value directly: -# lr = self.config['learning_rate'] -# -# - For optional parameters, use safe getters: -# dropout = self.config.get('dropout_rate', 0.1) # Default 0.1 if missing -# -# - Critical metadata is always present: -# print(f"Training {self.config['name']} for {self.config['targets']}") -# -# - During hyperparameter sweeps: -# - WandB overrides hyperparameter values -# - Original config remain for other parameters -# -# - Configuration is validated automatically for required keys: -# Required: ['name', 'algorithm', 'targets', 'steps'] -# -# TIP: All custom parameters added to your config files will automatically -# appear in self.config. Use consistent naming conventions. -# ===================================================================================== - - -class ExampleForecastingModelManager(ForecastingModelManager): - """ - Template for building a custom model manager. Follow these steps to implement - your model's training, evaluation, sweep, and forecasting functionality. - - Steps to implement: - 1. Initialize model-specific components in __init__ (if needed) - 2. Implement model training in _train_model_artifact() - 3. Implement model evaluation in _evaluate_model_artifact() - 4. Implement forecasting in _forecast_model_artifact() - 5. Implement sweep evaluation in _evaluate_sweep() - - Common variables available: - - self._model_path: Path manager for model directories - - self.config: Combined configuration dictionary - - self._data_loader: Data loader with partition information - """ - - def __init__( - self, - model_path: ModelPathManager, - wandb_notifications: bool = True, - use_prediction_store: bool = False - ) -> None: - """ - Initialize your custom model manager. - - USER IMPLEMENTATION: - - Add model-specific initialization here - - Call super() first to inherit base functionality - - EXAMPLE: - super().__init__(model_path, wandb_notifications, use_prediction_store) - self.special_component = YourComponent() - """ - super().__init__(model_path, wandb_notifications, use_prediction_store) - - # Add your custom initialization below - logger.info("Initializing CustomModelManager") - # YOUR CODE HERE - - def _train_model_artifact(self) -> any: - """ - Train and save your model artifact. - - Steps: - 1. Load training data - 2. Preprocess data - 3. Initialize model - 4. Train model - 5. Save artifact - - USER IMPLEMENTATION: - - Implement steps 2-4 with your model-specific logic - - Save artifact in step 5 using provided paths - - Returns: - Trained model object (used in sweeps) - """ - # Common paths and data loading (provided) - path_raw = self._model_path.data_raw # Path to raw data - path_artifacts = self._model_path.artifacts # Path to save model artifacts - run_type = self.config["run_type"] # e.g., "calibration", "validation", "forecasting" - df_viewser = read_dataframe( - path_raw / f"{run_type}_viewser_df{PipelineConfig.dataframe_format}" - ) # Dataframe obtained from viewser - partitioner_dict = self._data_loader.partition_dict # Partition dict from ViewsDataLoader - - # --- USER IMPLEMENTATION STARTS HERE --- - # 1. Preprocessing - logger.info("Preprocessing data") - # YOUR PREPROCESSING CODE HERE - - # 2. Model initialization - logger.info(f"Initializing model with config: {self.config}") - # YOUR MODEL INITIALIZATION CODE HERE - # Example: model = MyModel(**self.config['hyperparameters']) - - # 3. Model training - logger.info("Training model") - # YOUR TRAINING CODE HERE - # Example: model.fit(train_data) - - # 4. Save artifact (if not in sweep) - if not self.config["sweep"]: - model_filename = self.generate_model_file_name(run_type, ".pkl") - logger.info(f"Saving model artifact: {model_filename}") - # YOUR SAVING CODE HERE - # Example: model.save(path_artifacts / model_filename) - - return model # Return trained model for sweep evaluation - # --- USER IMPLEMENTATION ENDS HERE --- - - def _evaluate_model_artifact( - self, - eval_type: str, - artifact_name: str = None - ) -> list: - """ - Evaluate trained model artifact. - - Steps: - 1. Locate model artifact - 2. Load model - 3. Load evaluation data - 4. Generate predictions - 5. Return predictions - - USER IMPLEMENTATION: - - Implement steps 2 and 4 with model-specific logic - """ - # Common setup (provided) - path_raw = self._model_path.data_raw - path_artifacts = self._model_path.artifacts - run_type = self.config["run_type"] - - # Resolve artifact path - if artifact_name: - path_artifact = path_artifacts / artifact_name - else: - path_artifact = self._model_path.get_latest_model_artifact_path(run_type) - - self.config["timestamp"] = path_artifact.stem[-15:] - df_viewser = read_dataframe( - path_raw / f"{run_type}_viewser_df{PipelineConfig.dataframe_format}" - ) - - # --- USER IMPLEMENTATION STARTS HERE --- - # 1. Load model - logger.info(f"Loading model artifact: {path_artifact}") - # YOUR MODEL LOADING CODE HERE - # Example: model = MyModel.load(path_artifact) - - # 2. Generate predictions - # The expected format of your prediction dataframe can be found here: - # https://github.com/views-platform/views-pipeline-core/tree/main/views_pipeline_core/managers#dataframe-structures-for-evaluation-and-forecast-methods - - logger.info(f"Generating predictions for {eval_type} evaluation") - predictions = [] - - # Determine evaluation length - sequence_numbers = self._resolve_evaluation_sequence_number(eval_type) - for seq_num in range(sequence_numbers): - # YOUR PREDICTION CODE HERE - # Example: preds = model.predict(seq_num, steps=self.config['steps']) - predictions.append(preds) # Append predictions for each sequence - - return predictions - # --- USER IMPLEMENTATION ENDS HERE --- - - def _forecast_model_artifact(self, artifact_name: str = None) -> pd.DataFrame: - """ - Generate forecasts using trained model artifact. - - Steps: - 1. Locate model artifact - 2. Load model - 3. Load forecasting data - 4. Generate forecasts - 5. Return forecasts - - USER IMPLEMENTATION: - - Implement steps 2 and 4 with model-specific logic - """ - # Common setup (provided) - path_raw = self._model_path.data_raw - path_artifacts = self._model_path.artifacts - run_type = self.config["run_type"] - - # Resolve artifact path - if artifact_name: - path_artifact = path_artifacts / artifact_name - else: - path_artifact = self._model_path.get_latest_model_artifact_path(run_type) - - self.config["timestamp"] = path_artifact.stem[-15:] - df_viewser = read_dataframe( - path_raw / f"{run_type}_viewser_df{PipelineConfig.dataframe_format}" - ) - - # --- USER IMPLEMENTATION STARTS HERE --- - # 1. Load model - logger.info(f"Loading model for forecasting: {path_artifact}") - # YOUR MODEL LOADING CODE HERE - - # 2. Generate forecasts - logger.info("Generating forecasts") - # YOUR FORECASTING CODE HERE - # The expected format of your prediction dataframe can be found here: - # https://github.com/views-platform/views-pipeline-core/tree/main/views_pipeline_core/managers#dataframe-structures-for-evaluation-and-forecast-methods - # Example: forecasts = model.forecast(steps=self.config['steps']) - - return forecasts - # --- USER IMPLEMENTATION ENDS HERE --- - - def _evaluate_sweep(self, eval_type: str, model: any) -> list: - """ - Evaluate model during hyperparameter sweep (in-memory). - - USER IMPLEMENTATION: - - Implement evaluation using in-memory model - - Same prediction logic as _evaluate_model_artifact but without loading from disk - """ - # Common setup (provided) - path_raw = self._model_path.data_raw - run_type = self.config["run_type"] - df_viewser = read_dataframe( - path_raw / f"{run_type}_viewser_df{PipelineConfig.dataframe_format}" - ) - - # --- USER IMPLEMENTATION STARTS HERE --- - logger.info(f"Evaluating sweep model for {eval_type}") - predictions = [] - sequence_numbers = self._resolve_evaluation_sequence_number(eval_type) - - for seq_num in range(sequence_numbers): - # YOUR PREDICTION CODE HERE - # Example: preds = model.predict(seq_num, steps=self.config['steps']) - predictions.append(preds) - - return predictions - # --- USER IMPLEMENTATION ENDS HERE --- diff --git a/views_postprocessing/managers/unfao.py b/views_postprocessing/managers/unfao.py deleted file mode 100644 index c48512d..0000000 --- a/views_postprocessing/managers/unfao.py +++ /dev/null @@ -1,173 +0,0 @@ -from views_pipeline_core.managers.postprocessor import ( - PostprocessorManager, - PostprocessorPathManager, -) -from views_pipeline_core.files.utils import read_dataframe -from views_pipeline_core.configs.pipeline import PipelineConfig -import logging -from views_pipeline_core.data.handlers import PGMDataset - -from views_pipeline_core.managers.appwrite import AppwriteConfig -from views_pipeline_core.managers.prediction import PredictionStoreManager -from views_pipeline_core.managers.model import ForecastingModelManager - -from views_pipeline_core.managers.ensemble import EnsembleManager, EnsemblePathManager -import polars as pl -import pandas as pd -import io -from argparse import Namespace -from datetime import datetime -import os -from dotenv import load_dotenv - -logger = logging.getLogger(__name__) - - -class UNFAOPostProcessorManager(PostprocessorManager, ForecastingModelManager): - def __init__( - self, - model_path: PostprocessorPathManager, - wandb_notifications: bool = True, - use_prediction_store: bool = False, - ) -> None: - super().__init__(model_path, wandb_notifications, use_prediction_store) - - # Add your custom initialization below - logger.info(f"Initializing {self.__class__.__name__}") - self._historical_dataframe = None - self._forecast_dataframe = None - - self._historical_dataset = None - self._forecast_dataset = None - - def _read_historical_data(self): - # Historical Data - path_raw = self._model_path.data_raw # Path to raw data - path_artifacts = self._model_path.artifacts # Path to save model artifacts - run_type = "forecasting" # e.g., "calibration", "validation", "forecasting" - - self._data_loader.get_data( - use_saved=False, - validate=False, - self_test=False, - partition=run_type - ) - current_month = datetime.now().strftime("%Y-%m") - artifact_name = f"{run_type}_viewser_df_{current_month}" - self._historical_dataframe = read_dataframe( - path_raw / f"{run_type}_viewser_df{PipelineConfig.dataframe_format}" - ) # Dataframe obtained from viewser - partitioner_dict = ( - self._data_loader.partition_dict - ) # Partition dict from ViewsDataLoader - self._historical_dataset = PGMDataset( - source=self._historical_dataframe, targets=self.configs.get("targets") - ) - - def _read_forecast_data(self): - # Forecast Data - ensemble_name = self.configs.get("ensemble", None) - if not ensemble_name: - raise ValueError("Ensemble name must be provided in configs with the `ensemble` key for forecasting. Cannot proceed.") - - ensemble_path_manager = EnsemblePathManager(ensemble_name_or_path=ensemble_name, validate=False) - # ensemble_configs = EnsembleManager( - # ensemble_path=ensemble_path_manager, - # ).configs - - # loa = ensemble_configs.get("level", None) - loa = "pgm" - if not loa: - raise ValueError("level must be defined in the ensemble configurations (e.g, pgm, cm). Cannot proceed.") - - # Force it to the correct .env just to be safe - load_dotenv(dotenv_path=str(ensemble_path_manager.dotenv)) - - # appwrite_config = AppwriteConfig( - # path_manager=ensemble_path_manager, - # endpoint=os.getenv("APPWRITE_ENDPOINT"), - # project_id=os.getenv("APPWRITE_DATASTORE_PROJECT_ID"), - # credentials=os.getenv("APPWRITE_DATASTORE_API_KEY"), - # auth_method="api_key", - # cache_ttl_hours=24, - # bucket_id=os.getenv("APPWRITE_UNFAO_BUCKET_ID"), - # bucket_name=os.getenv("APPWRITE_UNFAO_BUCKET_NAME"), - # collection_name=os.getenv("APPWRITE_UNFAO_COLLECTION_NAME"), - # collection_id=os.getenv("APPWRITE_UNFAO_COLLECTION_ID"), - # database_id=os.getenv("APPWRITE_DATABASE_ID"), - # database_name=os.getenv("APPWRITE_DATABASE_NAME"), - # ) - appwrite_config = AppwriteConfig( - path_manager=self._model_path, - endpoint=os.getenv("APPWRITE_ENDPOINT"), - project_id=os.getenv("APPWRITE_DATASTORE_PROJECT_ID"), - credentials=os.getenv("APPWRITE_DATASTORE_API_KEY"), - auth_method="api_key", - cache_ttl_hours=24, - bucket_id=os.getenv("APPWRITE_UNFAO_BUCKET_ID"), - bucket_name=os.getenv("APPWRITE_UNFAOP_BUCKET_NAME"), - collection_id=os.getenv("APPWRITE_UNFAO_COLLECTION_ID"), - database_id=os.getenv("APPWRITE_UNFAO_DATABASE_ID"), - database_name=os.getenv("APPWRITE_UNFAO_DATABASE_NAME"), - ) - - try: - prediction_store_manager = PredictionStoreManager(appwrite_file_manager_config=appwrite_config) - self._forecast_dataframe = pd.read_parquet(io.BytesIO(prediction_store_manager.download_latest_file(filters={"category": "forecast"}).to_dict().get("data", {}).get("file_bytes", None))) - - self._forecast_dataset = PGMDataset(self._forecast_dataframe) - except Exception as e: - logger.error( - f"Encountered an error while trying to download the latest forecast data for level {loa} from Datastore: {e}", - exc_info=True, - ) - raise - - def _read(self) -> any: - self._read_historical_data() - self._read_forecast_data() - - def _append_m49(self): - pass - - def _append_lat_lon(self): - self._historical_dataframe = self._historical_dataframe.join(self._historical_dataset.get_lat_lon()) - self._forecast_dataframe = self._forecast_dataframe.join(self._forecast_dataset.get_lat_lon()) - - def _append_isoa3(self): - self._historical_dataframe = self._historical_dataframe.join(self._historical_dataset.get_isoab()) - self._forecast_dataframe = self._forecast_dataframe.join(self._forecast_dataset.get_isoab()) - - def _append_name(self): - self._historical_dataframe = self._historical_dataframe.join(self._historical_dataset.get_name()) - self._forecast_dataframe = self._forecast_dataframe.join(self._forecast_dataset.get_name()) - - def _transform( - self, - ) -> list: - self._append_m49() - self._append_lat_lon() - self._append_isoa3() - self._append_name() - - def _validate(self) -> pd.DataFrame: - # Common setup (provided) - pass - - def _save(self) -> list: - if self._historical_dataset is None or self._forecast_dataset is None: - raise ValueError("Datasets could not be initialized properly.") - - # self._historical_dataset.dataframe.to_parquet( - # self._model_path.data_generated / "historical_dataset.parquet" - # ) - # self._forecast_dataset.dataframe.to_parquet( - # self._model_path.data_generated / "forecast_dataset.parquet" - # ) - - self._historical_dataframe.to_parquet( - self._model_path.data_generated / "historical_dataset.parquet" - ) - self._forecast_dataframe.to_parquet( - self._model_path.data_generated / "forecast_dataset.parquet" - )