entropice/create_grid.py

232 lines
6.7 KiB
Python
Raw Normal View History

2025-09-26 11:05:29 +02:00
"""Create a global hexagonal grid using H3.
Author: Tobias Hölzer
Date: 09. June 2025
"""
from pathlib import Path
from typing import Literal

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cyclopts
import geopandas as gpd
import h3
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xdggs
import xvec # noqa: F401
from rich import pretty, print, traceback
from shapely.geometry import Polygon
from shapely.ops import transform
from stopuhr import stopwatch
from xdggs.healpix import HealpixInfo
# Install the `traceback` and `pretty` hooks imported from `rich` above.
# NOTE: these calls sit at module top level, so they run as a side effect of
# importing this module (presumably changing how errors / REPL values are
# rendered for the whole process — confirm against the rich documentation).
traceback.install()
pretty.install()
@stopwatch("Create a global hex grid")
def create_global_hex_grid(resolution):
"""Create a global hexagonal grid using H3.
Args:
resolution (int): H3 resolution level (0-15)
Returns:
GeoDataFrame: Global hexagonal grid
"""
# Generate hexagons
hex0_cells = h3.get_res0_cells()
if resolution > 0:
hex_cells = []
for hex0_cell in hex0_cells:
hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
else:
hex_cells = hex0_cells
# Initialize lists to store hex information
hex_list = []
hex_id_list = []
hex_area_list = []
# Convert each hex ID to a polygon
for hex_id in hex_cells:
boundary_coords = h3.cell_to_boundary(hex_id)
hex_polygon = Polygon(boundary_coords)
hex_polygon = transform(lambda x, y: (y, x), hex_polygon) # Convert from (lat, lon) to (lon, lat)
hex_area = h3.cell_area(hex_id, unit="km^2")
hex_list.append(hex_polygon)
hex_id_list.append(hex_id)
hex_area_list.append(hex_area)
# Create GeoDataFrame
grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326")
return grid
@stopwatch("Create a global HEALPix grid")
def create_global_healpix_grid(level: int):
"""Create a global HEALPix grid.
Args:
level (int): HEALPix level (0-12)
Returns:
GeoDataFrame: Global HEALPix grid
"""
grid_info = HealpixInfo(level=level, indexing_scheme="nested")
healpix_ds = xr.Dataset(
coords={
"cell_ids": (
"cells",
np.arange(12 * 4**grid_info.level),
grid_info.to_dict(),
)
}
).pipe(xdggs.decode)
cell_ids = healpix_ds.cell_ids.values
geometry = healpix_ds.dggs.cell_boundaries()
# Create GeoDataFrame
grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326")
grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2
return grid
@stopwatch("Filter grid")
def filter_permafrost_grid(grid: gpd.GeoDataFrame):
"""Filter an existing grid to permafrost extent & remove water.
Args:
grid (gpd.GeoDataFrame): Input grid
Returns:
gpd.GeoDataFrame: Filtered grid
"""
# Filter for Permafrost region (> 50° latitude)
grid = grid[grid.geometry.bounds.miny > 50]
# Filter for Arctic Sea (<85° latitude)
grid = grid[grid.geometry.bounds.maxy < 85]
# Convert to arctic stereographic projection
grid = grid.to_crs("EPSG:3413")
# Filter out non-land areas (e.g., oceans)
water_mask = gpd.read_file("./data/simplified-water-polygons-split-3857/simplified_water_polygons.shp")
water_mask = water_mask.to_crs("EPSG:3413")
ov = gpd.overlay(grid, water_mask, how="intersection")
ov["area"] = ov.geometry.area / 1e6 # Convert to km^2
ov = ov.groupby("cell_id").agg({"area": "sum"})
grid["water_area"] = grid["cell_id"].map(ov.area).fillna(0)
grid["land_area"] = grid["cell_area"] - grid["water_area"]
grid["land_ratio"] = grid["land_area"] / grid["cell_area"]
# Filter for land areas (> 10% land)
grid = grid[grid["land_ratio"] > 0.1]
return grid
def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
    """Vizualize the grid on a polar stereographic map.

    The cells are drawn semi-transparently on top of land, ocean, coastline,
    border, lake and river features and are colored by their ``cell_area``
    column. Cells whose lon/lat bounding box is wider than 180° are dropped
    before plotting (the code treats them as antimeridian-crossing cells).

    Args:
        data (gpd.GeoDataFrame): The grid data to visualize. Must contain a
            ``cell_area`` column.
        grid (str): The type of grid (e.g., "hex" or "healpix"); used in the title.
        level (int): The level of the grid; used in the title.

    Returns:
        plt.Figure: The matplotlib figure object.

    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
    ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree())

    # Add features
    ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white")
    ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey")
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=":")
    ax.add_feature(cfeature.LAKES, alpha=0.5)
    ax.add_feature(cfeature.RIVERS)

    # Add gridlines
    gl = ax.gridlines(draw_labels=True)
    gl.top_labels = False
    gl.right_labels = False

    # Plot grid cells, coloring by 'cell_area'
    data = data.to_crs("EPSG:4326")
    # Vectorized bounding-box width test instead of a row-wise Python apply.
    bounds = data.bounds
    is_anti_meridian = (bounds.maxx - bounds.minx) > 180
    data = data[~is_anti_meridian]
    data.plot(
        ax=ax,
        column="cell_area",
        cmap="viridis",
        legend=True,
        transform=ccrs.PlateCarree(),
        edgecolor="k",
        linewidth=0.2,
        aspect="equal",
        alpha=0.5,
    )
    ax.set_title(f"{grid.capitalize()} grid ({level=})", fontsize=14)

    # Compute a circle in axes coordinates, which we can use as a boundary
    # for the map. We can pan/zoom as much as we like - the boundary will be
    # permanently circular.
    theta = np.linspace(0, 2 * np.pi, 100)
    center, radius = [0.5, 0.5], 0.5
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
    circle = mpath.Path(verts * radius + center)
    ax.set_boundary(circle, transform=ax.transAxes)

    return fig
def cli(grid: Literal["hex", "healpix"], level: int):
    """Create a permafrost grid, save it as parquet and plot it.

    The global grid is created, filtered with ``filter_permafrost_grid`` and
    written to ``./data/grids``; a map of the result is written to
    ``./figures``. Both output directories are created if they do not exist.
    Nothing is written when no cell survives the filtering.

    Args:
        grid: The type of grid to create, either "hex" or "healpix".
        level: The level (resolution) of the grid.

    """
    print(f"Creating {grid} grid at level {level}...")
    if grid == "hex":
        grid_gdf = create_global_hex_grid(level)
    elif grid == "healpix":
        grid_gdf = create_global_healpix_grid(level)
    else:
        print(f"Unknown grid type: {grid}")
        return

    grid_gdf = filter_permafrost_grid(grid_gdf)
    print(f"Number of cells at level {level}: {len(grid_gdf)}")
    if grid_gdf.empty:
        print("No valid grid cells found.")
        return

    # Build each output path once so the written file and the log message cannot diverge.
    grid_file = f"./data/grids/permafrost_{grid}{level}_grid.parquet"
    figure_file = f"./figures/permafrost_{grid}{level}_grid.png"

    # Writing fails if the target directory is missing, so create it first.
    Path(grid_file).parent.mkdir(parents=True, exist_ok=True)
    grid_gdf.to_parquet(grid_file)
    print(f"Saved to {grid_file}")

    fig = vizualize_grid(grid_gdf, grid, level)
    Path(figure_file).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(figure_file, dpi=300)
    print(f"Saved figure to {figure_file}")
    plt.close(fig)
# Script entry point: when the file is executed directly, pass `cli` to cyclopts.
if __name__ == "__main__":
    cyclopts.run(cli)