"""Create global discrete grids (H3 hexagons or HEALPix) and filter them to the permafrost region.

The selected grid is restricted to the 50°N-85°N latitude band, reduced to cells with
more than 10% land, written as GeoParquet and plotted on a north polar map.

Author: Tobias Hölzer
Date: 09. June 2025
"""

import os
from pathlib import Path
from typing import Literal

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cyclopts
import geopandas as gpd
import h3
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xdggs
import xvec  # noqa: F401
from rich import pretty, print, traceback
from shapely.geometry import Polygon
from shapely.ops import transform
from stopuhr import stopwatch
from xdggs.healpix import HealpixInfo

traceback.install()
pretty.install()

DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = DATA_DIR / "figures"
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Earth radius (km) that H3 uses for its spherical cell areas (H3's EARTH_RADIUS_KM).
# Reused for the analytic HEALPix cell area so both grid types report areas on the same sphere.
EARTH_RADIUS_KM = 6371.007180918475

# CRS of the filtered grid that is returned and saved (NSIDC Sea Ice Polar Stereographic North).
POLAR_CRS = "EPSG:3413"

# Equal-area CRS used for area measurements (WGS 84 / NSIDC EASE-Grid 2.0 North, Lambert
# azimuthal equal-area). EPSG:3413 is conformal and distorts areas away from 70°N.
EQUAL_AREA_CRS = "EPSG:6931"


@stopwatch("Create a global hex grid")
def create_global_hex_grid(resolution):
    """Create a global hexagonal grid using H3.

    Args:
        resolution (int): H3 resolution level (0-15)

    Returns:
        GeoDataFrame: Global hexagonal grid with columns ``cell_id`` (H3 index),
        ``cell_area`` (km², from ``h3.cell_area``) and ``geometry`` (EPSG:4326 polygons).

    """
    # Enumerate every cell at the requested resolution by expanding the resolution-0 base cells.
    hex0_cells = h3.get_res0_cells()
    if resolution > 0:
        hex_cells = [child for hex0_cell in hex0_cells for child in h3.cell_to_children(hex0_cell, resolution)]
    else:
        hex_cells = list(hex0_cells)

    hex_polygons = []
    hex_areas = []
    for hex_id in hex_cells:
        # h3 returns boundary vertices as (lat, lng); shapely / EPSG:4326 expect (lon, lat).
        hex_polygon = transform(lambda x, y: (y, x), Polygon(h3.cell_to_boundary(hex_id)))
        hex_polygons.append(hex_polygon)
        hex_areas.append(h3.cell_area(hex_id, unit="km^2"))

    # NOTE: cells crossing the anti-meridian have vertices on both sides of ±180° and therefore
    # span the whole longitude range in EPSG:4326. Reprojecting to a polar CRS transforms the
    # vertices one by one, which yields the correct small polygon again.
    return gpd.GeoDataFrame(
        {"cell_id": hex_cells, "cell_area": hex_areas, "geometry": hex_polygons},
        crs="EPSG:4326",
    )


@stopwatch("Create a global HEALPix grid")
def create_global_healpix_grid(level: int):
    """Create a global HEALPix grid.

    Args:
        level (int): HEALPix level (0-12)

    Returns:
        GeoDataFrame: Global HEALPix grid (nested indexing) with columns ``cell_id``,
        ``geometry`` (EPSG:4326 polygons) and ``cell_area`` (km²).

    """
    grid_info = HealpixInfo(level=level, indexing_scheme="nested")
    n_cells = 12 * 4**grid_info.level
    healpix_ds = xr.Dataset(
        coords={
            "cell_ids": (
                "cells",
                np.arange(n_cells),
                grid_info.to_dict(),
            )
        }
    ).pipe(xdggs.decode)

    cell_ids = healpix_ds.cell_ids.values
    geometry = healpix_ds.dggs.cell_boundaries()

    grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326")
    # HEALPix is an equal-area tessellation: every cell covers exactly 1 / n_cells of the sphere.
    # The analytic value avoids projecting to EPSG:3413, which is not equal-area and breaks down
    # for southern cells. It also uses the same sphere as h3.cell_area, so grid types are comparable.
    grid["cell_area"] = 4 * np.pi * EARTH_RADIUS_KM**2 / n_cells
    return grid


@stopwatch("Filter grid")
def filter_permafrost_grid(
    grid: gpd.GeoDataFrame,
    *,
    min_lat: float = 50,
    max_lat: float = 85,
    min_land_ratio: float = 0.1,
):
    """Filter an existing grid to permafrost extent & remove water.

    Args:
        grid (gpd.GeoDataFrame): Input grid with ``cell_id``, ``cell_area`` (km²) and a geometry column.
        min_lat (float): Cells must lie entirely north of this latitude (permafrost region).
        max_lat (float): Cells must lie entirely south of this latitude (excludes the central Arctic Ocean).
        min_land_ratio (float): Cells must have a land fraction strictly greater than this value.

    Returns:
        gpd.GeoDataFrame: Filtered grid in EPSG:3413. It gains ``water_area`` and ``land_area``
        columns (km²) and a ``land_ratio`` column (land_area / cell_area).

    """
    # Latitude band, judged on the geographic bounds of each cell.
    bounds = grid.geometry.bounds
    grid = grid[(bounds.miny > min_lat) & (bounds.maxy < max_lat)]

    grid = grid.to_crs(POLAR_CRS)
    if grid.empty:
        # Nothing to intersect; keep the output schema consistent for callers.
        return grid.assign(water_area=0.0, land_area=0.0, land_ratio=0.0)

    # Filter out non-land areas (e.g., oceans)
    water_mask = gpd.read_file(DATA_DIR / "simplified-water-polygons-split-3857/simplified_water_polygons.shp")
    water_mask = water_mask.to_crs(EQUAL_AREA_CRS)

    # Intersect in an equal-area CRS so water areas are true areas, comparable to cell_area.
    # Only cell_id and the geometries are carried through, to avoid column-name clashes.
    cells_ea = grid[["cell_id", grid.geometry.name]].to_crs(EQUAL_AREA_CRS)
    ov = gpd.overlay(cells_ea, water_mask[[water_mask.geometry.name]], how="intersection")
    water_area = (ov.geometry.area / 1e6).groupby(ov["cell_id"]).sum()  # km^2 per cell

    grid["water_area"] = grid["cell_id"].map(water_area).fillna(0)
    grid["land_area"] = grid["cell_area"] - grid["water_area"]
    grid["land_ratio"] = grid["land_area"] / grid["cell_area"]

    return grid[grid["land_ratio"] > min_land_ratio]


def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
    """Visualize the grid on a north polar stereographic map.

    Args:
        data (gpd.GeoDataFrame): The grid data to visualize (must contain ``cell_area``).
        grid (str): The type of grid (e.g., "hex" or "healpix").
        level (int): The level of the grid.

    Returns:
        plt.Figure: The matplotlib figure object.

    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
    ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree())

    # Add features
    ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white")
    ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey")
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=":")
    ax.add_feature(cfeature.LAKES, alpha=0.5)
    ax.add_feature(cfeature.RIVERS)

    # Add gridlines
    gl = ax.gridlines(draw_labels=True)
    gl.top_labels = False
    gl.right_labels = False

    # Cells straddling the anti-meridian span > 180° of longitude in EPSG:4326. They would be
    # drawn as bands around the globe, so they are left out of the plot.
    data = data.to_crs("EPSG:4326")
    bounds = data.bounds
    data = data[~((bounds.maxx - bounds.minx) > 180)]

    # Plot grid cells, coloring by 'cell_area'
    data.plot(
        ax=ax,
        column="cell_area",
        cmap="viridis",
        legend=True,
        legend_kwds={"label": "Cell area [km²]"},
        transform=ccrs.PlateCarree(),
        edgecolor="k",
        linewidth=0.2,
        aspect="equal",
        alpha=0.5,
    )

    ax.set_title(f"{grid.capitalize()} grid ({level=})", fontsize=14)

    # Compute a circle in axes coordinates, which we can use as a boundary
    # for the map. We can pan/zoom as much as we like - the boundary will be
    # permanently circular.
    theta = np.linspace(0, 2 * np.pi, 100)
    center, radius = [0.5, 0.5], 0.5
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
    circle = mpath.Path(verts * radius + center)
    ax.set_boundary(circle, transform=ax.transAxes)

    return fig


def cli(grid: Literal["hex", "healpix"], level: int):
    """Create, filter, save and plot a permafrost grid.

    Args:
        grid: Grid type, either "hex" (H3) or "healpix".
        level: H3 resolution or HEALPix level.

    """
    print(f"Creating {grid} grid at level {level}...")
    if grid == "hex":
        grid_gdf = create_global_hex_grid(level)
    elif grid == "healpix":
        grid_gdf = create_global_healpix_grid(level)
    else:
        print(f"Unknown grid type: {grid}")
        return

    grid_gdf = filter_permafrost_grid(grid_gdf)
    print(f"Number of cells at level {level}: {len(grid_gdf)}")
    if grid_gdf.empty:
        print("No valid grid cells found.")
        return

    grid_file = GRIDS_DIR / f"permafrost_{grid}{level}_grid.parquet"
    grid_gdf.to_parquet(grid_file)
    print(f"Saved to {grid_file.resolve()}")

    fig = vizualize_grid(grid_gdf, grid, level)
    fig_file = FIGURES_DIR / f"permafrost_{grid}{level}_grid.png"
    fig.savefig(fig_file, dpi=300)
    print(f"Saved figure to {fig_file.resolve()}")
    plt.close(fig)


def main():  # noqa: D103
    cyclopts.run(cli)


if __name__ == "__main__":
    main()