Restructure to steps

This commit is contained in:
Tobias Hölzer 2025-10-21 18:42:01 +02:00
parent 2af5c011a3
commit ce4c728e1a
10 changed files with 1377 additions and 640 deletions

View file

@ -0,0 +1,245 @@
"""Create a global hexagonal grid using H3.
Author: Tobias Hölzer
Date: 09. June 2025
"""
import os
from pathlib import Path
from typing import Literal
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cyclopts
import geopandas as gpd
import h3
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xdggs
import xvec # noqa: F401
from rich import pretty, print, traceback
from shapely.geometry import Polygon
from shapely.ops import transform
from stopuhr import stopwatch
from xdggs.healpix import HealpixInfo
traceback.install()
pretty.install()
DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = DATA_DIR / "figures"
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
@stopwatch("Create a global hex grid")
def create_global_hex_grid(resolution):
"""Create a global hexagonal grid using H3.
Args:
resolution (int): H3 resolution level (0-15)
Returns:
GeoDataFrame: Global hexagonal grid
"""
# Generate hexagons
hex0_cells = h3.get_res0_cells()
if resolution > 0:
hex_cells = []
for hex0_cell in hex0_cells:
hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
else:
hex_cells = hex0_cells
# Initialize lists to store hex information
hex_list = []
hex_id_list = []
hex_area_list = []
# Convert each hex ID to a polygon
for hex_id in hex_cells:
boundary_coords = h3.cell_to_boundary(hex_id)
hex_polygon = Polygon(boundary_coords)
hex_polygon = transform(lambda x, y: (y, x), hex_polygon) # Convert from (lat, lon) to (lon, lat)
hex_area = h3.cell_area(hex_id, unit="km^2")
hex_list.append(hex_polygon)
hex_id_list.append(hex_id)
hex_area_list.append(hex_area)
# Create GeoDataFrame
grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326")
return grid
@stopwatch("Create a global HEALPix grid")
def create_global_healpix_grid(level: int):
"""Create a global HEALPix grid.
Args:
level (int): HEALPix level (0-12)
Returns:
GeoDataFrame: Global HEALPix grid
"""
grid_info = HealpixInfo(level=level, indexing_scheme="nested")
healpix_ds = xr.Dataset(
coords={
"cell_ids": (
"cells",
np.arange(12 * 4**grid_info.level),
grid_info.to_dict(),
)
}
).pipe(xdggs.decode)
cell_ids = healpix_ds.cell_ids.values
geometry = healpix_ds.dggs.cell_boundaries()
# Create GeoDataFrame
grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326")
grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2
return grid
@stopwatch("Filter grid")
def filter_permafrost_grid(grid: gpd.GeoDataFrame):
"""Filter an existing grid to permafrost extent & remove water.
Args:
grid (gpd.GeoDataFrame): Input grid
Returns:
gpd.GeoDataFrame: Filtered grid
"""
# Filter for Permafrost region (> 50° latitude)
grid = grid[grid.geometry.bounds.miny > 50]
# Filter for Arctic Sea (<85° latitude)
grid = grid[grid.geometry.bounds.maxy < 85]
# Convert to arctic stereographic projection
grid = grid.to_crs("EPSG:3413")
# Filter out non-land areas (e.g., oceans)
water_mask = gpd.read_file(DATA_DIR / "simplified-water-polygons-split-3857/simplified_water_polygons.shp")
water_mask = water_mask.to_crs("EPSG:3413")
ov = gpd.overlay(grid, water_mask, how="intersection")
ov["area"] = ov.geometry.area / 1e6 # Convert to km^2
ov = ov.groupby("cell_id").agg({"area": "sum"})
grid["water_area"] = grid["cell_id"].map(ov.area).fillna(0)
grid["land_area"] = grid["cell_area"] - grid["water_area"]
grid["land_ratio"] = grid["land_area"] / grid["cell_area"]
# Filter for land areas (> 10% land)
grid = grid[grid["land_ratio"] > 0.1]
return grid
def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
"""Vizualize the grid on a polar stereographic map.
Args:
data (gpd.GeoDataFrame): The grid data to visualize.
grid (str): The type of grid (e.g., "hex" or "healpix").
level (int): The level of the grid.
Returns:
plt.Figure: The matplotlib figure object.
"""
fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree())
# Add features
ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white")
ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey")
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.BORDERS, linestyle=":")
ax.add_feature(cfeature.LAKES, alpha=0.5)
ax.add_feature(cfeature.RIVERS)
# Add gridlines
gl = ax.gridlines(draw_labels=True)
gl.top_labels = False
gl.right_labels = False
# Plot grid cells, coloring by 'cell_area'
data = data.to_crs("EPSG:4326")
is_anti_meridian = data.bounds.apply(lambda b: (b.maxx - b.minx) > 180, axis=1)
data = data[~is_anti_meridian]
data.plot(
ax=ax,
column="cell_area",
cmap="viridis",
legend=True,
transform=ccrs.PlateCarree(),
edgecolor="k",
linewidth=0.2,
aspect="equal",
alpha=0.5,
)
ax.set_title(f"{grid.capitalize()} grid ({level=})", fontsize=14)
# Compute a circle in axes coordinates, which we can use as a boundary
# for the map. We can pan/zoom as much as we like - the boundary will be
# permanently circular.
theta = np.linspace(0, 2 * np.pi, 100)
center, radius = [0.5, 0.5], 0.5
verts = np.vstack([np.sin(theta), np.cos(theta)]).T
circle = mpath.Path(verts * radius + center)
ax.set_boundary(circle, transform=ax.transAxes)
return fig
def cli(grid: Literal["hex", "healpix"], level: int):
"""CLI entry point."""
print(f"Creating {grid} grid at level {level}...")
if grid == "hex":
grid_gdf = create_global_hex_grid(level)
elif grid == "healpix":
grid_gdf = create_global_healpix_grid(level)
else:
print(f"Unknown grid type: {grid}")
return
grid_gdf = filter_permafrost_grid(grid_gdf)
print(f"Number of cells at level {level}: {len(grid_gdf)}")
if not len(grid_gdf):
print("No valid grid cells found.")
return
grid_file = GRIDS_DIR / f"permafrost_{grid}{level}_grid.parquet"
grid_gdf.to_parquet(grid_file)
print(f"Saved to {grid_file.resolve()}")
fig = vizualize_grid(grid_gdf, grid, level)
fig_file = FIGURES_DIR / f"permafrost_{grid}{level}_grid.png"
fig.savefig(fig_file, dpi=300)
print(f"Saved figure to {fig_file.resolve()}")
plt.close(fig)
def main(): # noqa: D103
cyclopts.run(cli)
if __name__ == "__main__":
main()

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,101 @@
"""Extract satellite embeddings from Google Earth Engine and map them to a grid."""
import os
import warnings
from pathlib import Path
from typing import Literal
import cyclopts
import ee
import geemap
import geopandas as gpd
import numpy as np
import pandas as pd
from rich import pretty, print, traceback
from rich.progress import track
# Filter out the GeoDataFrame.swapaxes deprecation warning
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)
pretty.install()
traceback.install()
ee.Initialize(project="ee-tobias-hoelzer")
DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
def cli(grid: Literal["hex", "healpix"], level: int, backup_intermediate: bool = False):
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
Args:
grid (Literal["hex", "healpix"]): The grid type to use.
level (int): The grid level to use.
backup_intermediate (bool, optional): Whether to backup intermediate results. Defaults to False.
"""
gridname = f"permafrost_{grid}{level}"
grid = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
for year in track(range(2017, 2025), total=8, description="Processing years..."):
embedding_collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
embedding_collection = embedding_collection.filterDate(f"{year}-01-01", f"{year}-12-31")
aggs = ["median", "stdDev", "min", "max", "mean", "p1", "p5", "p25", "p75", "p95", "p99"]
bands = [f"A{str(i).zfill(2)}_{agg}" for i in range(64) for agg in aggs]
def extract_embedding(feature):
# Filter collection by geometry
geom = feature.geometry()
embedding = embedding_collection.filterBounds(geom).mosaic()
# Get mean embedding value for the geometry
mean_dict = embedding.reduceRegion(
reducer=ee.Reducer.median()
.combine(ee.Reducer.stdDev(), sharedInputs=True)
.combine(ee.Reducer.minMax(), sharedInputs=True)
.combine(ee.Reducer.mean(), sharedInputs=True)
.combine(ee.Reducer.percentile([1, 5, 25, 75, 95, 99]), sharedInputs=True),
geometry=geom,
)
# Add mean embedding values as properties to the feature
return feature.set(mean_dict)
# Process grid in batches of 100
batch_size = 100
all_results = []
n_batches = len(grid) // batch_size
for batch_num, batch_grid in track(
enumerate(np.array_split(grid, n_batches)),
description="Processing batches...",
total=n_batches,
):
# Convert batch to EE FeatureCollection
eegrid_batch = ee.FeatureCollection(batch_grid.to_crs("epsg:4326").__geo_interface__)
# Apply embedding extraction to batch
eeegrid_batch = eegrid_batch.map(extract_embedding)
df_batch = geemap.ee_to_df(eeegrid_batch)
# Store batch results
all_results.append(df_batch)
# Save batch immediately to disk as backup
if backup_intermediate:
batch_filename = f"{gridname}_embeddings-{year}_batch{batch_num:06d}.parquet"
batch_result = batch_grid.merge(df_batch[[*bands, "cell_id"]], on="cell_id", how="left")
batch_result.to_parquet(EMBEDDINGS_DIR / f"{batch_filename}")
# Combine all batch results
df = pd.concat(all_results, ignore_index=True)
embeddings_on_grid = grid.merge(df[[*bands, "cell_id"]], on="cell_id", how="left")
embeddings_file = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
embeddings_on_grid.to_parquet(embeddings_file)
print(f"Saved embeddings for year {year} to {embeddings_file.resolve()}.")
def main(): # noqa: D103
cyclopts.run(cli)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,9 @@
#!/bin/bash
# uv run alpha-earth --grid hex --level 3
uv run alpha-earth --grid hex --level 4
uv run alpha-earth --grid hex --level 5
uv run alpha-earth --grid healpix --level 6
uv run alpha-earth --grid healpix --level 7
uv run alpha-earth --grid healpix --level 8
uv run alpha-earth --grid healpix --level 9

66
steps/s1_1_era5/cds.py Normal file
View file

@ -0,0 +1,66 @@
"""Download ERA5 data from the Copernicus Data Store.
Web platform: https://cds.climate.copernicus.eu
"""
import os
import re
from pathlib import Path
import cdsapi
import cyclopts
from rich import pretty, print, traceback
traceback.install()
pretty.install()
DATA_DIR = Path(os.environ.get("DATA_DIR", "../../data")) / "entropyc-rts"
def hourly(years: str):
"""Download ERA5 data from the Copernicus Data Store.
Args:
years (str): Years to download, seperated by a '-'.
"""
assert re.compile(r"^\d{4}-\d{4}$").match(years), "Years must be in the format 'YYYY-YYYY'"
start_year, end_year = map(int, years.split("-"))
assert 1950 <= start_year <= end_year <= 2024, "Years must be between 1950 and 2024"
dataset = "reanalysis-era5-single-levels"
client = cdsapi.Client(wait_until_complete=False)
outdir = (DATA_DIR / "era5/cds").resolve()
outdir.mkdir(parents=True, exist_ok=True)
print(f"Downloading ERA5 data from {start_year} to {end_year}...")
for y in range(start_year, end_year + 1):
for month in [f"{i:02d}" for i in range(1, 13)]:
request = {
"product_type": ["reanalysis"],
"variable": [
"2m_temperature",
"total_precipitation",
"snow_depth",
"snow_density",
"snowfall",
"lake_ice_temperature",
"surface_sensible_heat_flux",
],
"year": [str(y)],
"month": [month],
"day": [f"{i:02d}" for i in range(1, 32)],
"time": [f"{i:02d}:00" for i in range(0, 24)],
"data_format": "netcdf",
"download_format": "unarchived",
"area": [85, -180, 50, 180],
}
outpath = outdir / f"era5_{y}_{month}.zip"
client.retrieve(dataset, request).download(str(outpath))
print(f"Downloaded {dataset} for {y}-{month}")
if __name__ == "__main__":
cyclopts.run(hourly)

576
steps/s1_1_era5/era5.py Normal file
View file

@ -0,0 +1,576 @@
"""Download and preprocess ERA5 data.
Variables of Interest:
- 2 metre temperature (t2m) [instant]
- Total precipitation (tp) [accum]
- Snow Fall (sf) [accum]
- Snow cover (snowc) [instant]
- Snow depth (sde) [instant]
- Surface sensible heat flux (sshf) [accum]
- Lake ice bottom temperature (lblt) [instant]
Naming patterns:
- Instant Variables are downloaded already as statistically aggregated (lossy),
therefore their names get the aggregation as suffix
- Accumulation Variables are downloaded as totals, their names stay the same
Daily Variables (downloaded from hourly data):
- t2m_max
- t2m_min
- snowc_mean
- sde_mean
- lblt_max
- tp
- sf
- sshf
Derived Daily Variables:
- t2m_daily_avg
- t2m_daily_range
- t2m_daily_skew
- thawing_degree_days
- freezing_degree_days
- thawing_days
- freezing_days
- precipitation_occurrences
- snowfall_occurrences
- snow_isolation (snowc * sde)
Monthly Variables:
- t2m_monthly_max
- t2m_monthly_min
- tp_monthly_sum
- sf_monthly_sum
- snowc_monthly_mean
- sde_monthly_mean
- sshf_monthly_sum
- lblt_monthly_max
- t2m_monthly_avg
- t2m_monthly_range_avg
- t2m_monthly_skew_avg
- thawing_degree_days_monthly
- freezing_degree_days_monthly
- thawing_days_monthly
- freezing_days_monthly
- precipitation_occurrences_monthly TODO: Rename to precipitation_days_monthly?
- snowfall_occurrences_monthly TODO: Rename to snowfall_days_monthly?
- snow_isolation_monthly_mean
Yearly Variables:
- TODO
# TODO Variables:
- Day of first thaw (yearly)
- Day of last thaw (yearly)
- Thawing period length (yearly)
- Freezing period length (yearly)
Author: Tobias Hölzer
Date: 09. June 2025
"""
import os
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Literal
import cyclopts
import dask.distributed as dd
import geopandas as gpd
import odc.geo
import odc.geo.xr
import pandas as pd
import shapely
import shapely.ops
import xarray as xr
from numcodecs.zarr3 import Blosc
from rich import pretty, print, traceback
from rich.progress import track
from shapely.geometry import LineString, Polygon
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd])
pretty.install()
cli = cyclopts.App()
# TODO: Directly handle download on a grid level - this is more what the zarr access is indented to do
DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) / "entropyc-rts"
ERA5_DIR = DATA_DIR / "era5"
DAILY_RAW_PATH = ERA5_DIR / "daily_raw.zarr"
def _get_grid_paths(
agg: Literal["daily", "monthly", "summer", "winter", "yearly"],
grid: Literal["hex", "healpix"],
level: int,
):
gridname = f"permafrost_{grid}{level}"
aligned_path = ERA5_DIR / f"{agg}_{gridname}.zarr"
return aligned_path
min_lat = 50
max_lat = 83.7 # Ensures the right Chunks Size (90 - 64 / 10 + 0.1)
min_time = "1990-01-01"
max_time = "2024-12-31"
today = time.strftime("%Y-%m-%d")
# ================
# === Download ===
# ================
def create_encoding(ds: xr.Dataset):
"""Create compression encoding for zarr dataset storage.
Creates Blosc compression configuration for all data variables and coordinates
in the dataset using zstd compression with level 9.
Args:
ds (xr.Dataset): The xarray Dataset to create encoding for.
Returns:
dict: Encoding dictionary with compression settings for each variable.
"""
# encoding = {var: {"compressors": BloscCodec(cname="zlib", clevel=9)} for var in ds.data_vars}
encoding = {var: {"compressors": Blosc(cname="zstd", clevel=9)} for var in [*ds.data_vars, *ds.coords]}
return encoding
def download_daily_aggregated():
"""Download and aggregate ERA5 data to daily resolution.
Downloads ERA5 reanalysis data from the DESTINE Earth Data Hub and aggregates
it to daily resolution. Includes temperature extremes, precipitation, snow,
and surface heat flux variables.
The function downloads hourly data and creates daily aggregates:
- Temperature: daily min/max
- Precipitation and snowfall: daily totals
- Snow cover and depth: daily means
- Surface heat flux: daily totals
- Lake ice temperature: daily max
The aggregated data is saved to a zarr file with compression.
"""
era5 = xr.open_dataset(
"https://data.earthdatahub.destine.eu/era5/reanalysis-era5-land-no-antartica-v0.zarr",
storage_options={"client_kwargs": {"trust_env": True}},
chunks={},
# chunks={},
engine="zarr",
).rename({"valid_time": "time"})
subset = {
"latitude": slice(max_lat, min_lat),
}
# Compute the clostest chunk-start to min_time, to avoid problems with cropped chunks at the start
tchunksize = era5.chunksizes["time"][0]
era5_chunk_starts = pd.date_range(era5.time.min().item(), era5.time.max().item(), freq=f"{tchunksize}h")
closest_chunk_start = era5_chunk_starts[
era5_chunk_starts.get_indexer([pd.to_datetime(min_time)], method="ffill")[0]
]
subset["time"] = slice(str(closest_chunk_start), max_time)
era5 = era5.sel(**subset)
daily_raw = xr.merge(
[
# Instant
era5.t2m.resample(time="1D").max().rename("t2m_max"),
era5.t2m.resample(time="1D").min().rename("t2m_min"),
era5.snowc.resample(time="1D").mean().rename("snowc_mean"),
era5.sde.resample(time="1D").mean().rename("sde_mean"),
era5.lblt.resample(time="1D").max().rename("lblt_max"),
# Accum
era5.tp.resample(time="1D").sum().rename("tp"),
era5.sf.resample(time="1D").sum().rename("sf"),
era5.sshf.resample(time="1D").sum().rename("sshf"),
]
)
# Assign attributes
daily_raw["t2m_max"].attrs = {"long_name": "Daily maximum 2 metre temperature", "units": "K"}
daily_raw["t2m_min"].attrs = {"long_name": "Daily minimum 2 metre temperature", "units": "K"}
daily_raw["tp"].attrs = {"long_name": "Daily total precipitation", "units": "m"}
daily_raw["sf"].attrs = {"long_name": "Daily total snow fall", "units": "m"}
daily_raw["snowc_mean"].attrs = {"long_name": "Daily mean snow cover", "units": "m"}
daily_raw["sde_mean"].attrs = {"long_name": "Daily mean snow depth", "units": "m"}
daily_raw["sshf"].attrs = {"long_name": "Daily total surface sensible heat flux", "units": "J/m²"}
daily_raw["lblt_max"].attrs = {"long_name": "Daily maximum lake ice bottom temperature", "units": "K"}
daily_raw = daily_raw.odc.assign_crs("epsg:4326")
daily_raw = daily_raw.drop_vars(["surface", "number", "depthBelowLandLayer"])
daily_raw.to_zarr(DAILY_RAW_PATH, mode="w", encoding=create_encoding(daily_raw), consolidated=False)
@cli.command
def download():
"""Download ERA5 data using Dask cluster for parallel processing.
Creates a local Dask cluster and downloads daily aggregated ERA5 data.
The cluster is configured with a single worker with 10 threads and 100GB
memory limit for optimal performance.
"""
with (
dd.LocalCluster(n_workers=1, threads_per_worker=10, memory_limit="100GB") as cluster,
dd.Client(cluster) as client,
):
print(client)
print(client.dashboard_link)
download_daily_aggregated()
print(f"Downloaded and aggregated ERA5 data to {DAILY_RAW_PATH.resolve()}.")
# ===========================
# === Spatial Aggregation ===
# ===========================
def _crosses_antimeridian(geom: Polygon) -> bool:
coords = shapely.get_coordinates(geom)
crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any()
return crosses_any_meridian and abs(coords[:, 0]).max() > 90
def _split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
# Assumes that it is a antimeridian hex
coords = shapely.get_coordinates(geom)
for i in range(coords.shape[0]):
if coords[i, 0] < 0:
coords[i, 0] += 360
geom = Polygon(coords)
antimeridian = LineString([[180, -90], [180, 90]])
polys = shapely.ops.split(geom, antimeridian)
return list(polys.geoms)
def _check_geobox(geobox):
x, y = geobox.shape
return x > 1 and y > 1
def extract_cell_data(idx: int, geom: Polygon) -> xr.Dataset:
"""Extract ERA5 data for a specific grid cell geometry.
Extracts and spatially averages ERA5 data within the bounds of a grid cell.
Handles antimeridian-crossing cells by splitting them appropriately.
Args:
idx (int): Index of the grid cell.
geom (Polygon): Polygon geometry of the grid cell.
Returns:
xr.Dataset: The computed cell dataset
"""
daily_raw = xr.open_zarr(DAILY_RAW_PATH, consolidated=False).set_coords("spatial_ref")
# cell.geometry is a shapely Polygon
if not _crosses_antimeridian(geom):
geoms = [geom]
# Split geometry in case it crossed antimeridian
else:
geoms = _split_antimeridian_cell(geom)
cell_data = []
for geom in geoms:
geom = odc.geo.Geometry(geom, crs="epsg:4326")
if not _check_geobox(daily_raw.odc.geobox.enclosing(geom)):
continue
# TODO: use mean for instant variables, sum for accum variables
cell_data.append(daily_raw.odc.crop(geom).drop_vars("spatial_ref").mean(["latitude", "longitude"]))
if len(cell_data) == 0:
return False
elif len(cell_data) == 1:
cell_data = cell_data[0]
else:
cell_data = xr.concat(cell_data, dim="part").mean("part")
cell_data = cell_data.expand_dims({"cell": [idx]}).compute()
return cell_data
@cli.command
def spatial_agg(
grid: Literal["hex", "healpix"],
level: int,
n_workers: int = 10,
executor: Literal["threads", "processes"] = "threads",
):
"""Perform spatial aggregation of ERA5 data to grid cells.
Loads a grid and spatially aggregates ERA5 data to each grid cell using
parallel processing. Creates an empty zarr file first, then fills it
with extracted data for each cell.
Args:
grid ("hex" | "healpix"): Grid type.
level (int): Grid resolution level.
n_workers (int, optional): Number of parallel workers to use. Defaults to 10.
executor ("threads" | "processes"): The type of parallel executor pool to use. Defaults to threads.
"""
gridname = f"permafrost_{grid}{level}"
daily_grid_path = _get_grid_paths("daily", grid, level)
grid = gpd.read_parquet(DATA_DIR / f"grids/{gridname}_grid.parquet")
# Create an empty zarr array with the right dimensions
daily_raw = xr.open_zarr(DAILY_RAW_PATH, consolidated=False).set_coords("spatial_ref")
assert {"latitude", "longitude", "time"} == set(daily_raw.dims), (
f"Expected dims ('latitude', 'longitude', 'time'), got {daily_raw.dims}"
)
assert daily_raw.odc.crs == "epsg:4326", f"Expected CRS 'epsg:4326', got {daily_raw.odc.crs}"
daily = (
xr.zeros_like(daily_raw.isel(latitude=0, longitude=0))
.expand_dims({"cell": [idx for idx, _ in grid.iterrows()]})
.chunk({"cell": min(len(grid), 1000), "time": len(daily_raw.time)}) # ~50MB chunks
)
daily.to_zarr(daily_grid_path, mode="w", consolidated=False, encoding=create_encoding(daily))
print(f"Created empty zarr at {daily_grid_path.resolve()} with shape {daily.sizes}.")
print(f"Starting spatial matching of {len(grid)} cells with {n_workers} workers...")
ExecutorCls = ThreadPoolExecutor if executor == "threads" else ProcessPoolExecutor
with ExecutorCls(max_workers=n_workers) as executor:
futures = {
executor.submit(extract_cell_data, idx, row.geometry): idx
for idx, row in grid.to_crs("epsg:4326").iterrows()
}
for future in track(as_completed(futures), total=len(futures), description="Processing cells"):
idx = futures[future]
try:
cell_data = future.result()
if not cell_data:
print(f"Cell {idx} did not overlap with ERA5 data.")
cell_data.to_zarr(daily_grid_path, region="auto", consolidated=False)
print(f"Successfully written cell {idx}")
except Exception as e:
print(f"{type(e)} processing cell {idx}: {e}")
print("Finished spatial matching.")
# ============================
# === Temporal Aggregation ===
# ============================
def daily_enrich(grid: Literal["hex", "healpix"], level: int) -> xr.Dataset:
"""Enrich daily ERA5 data with derived climate variables.
Loads spatially aligned ERA5 data and computes additional climate variables.
Creates derived variables including temperature statistics, degree days, and occurrence indicators.
Derived variables include:
- Daily average and range temperature
- Temperature skewness
- Thawing and freezing degree days
- Thawing and freezing day counts
- Precipitation and snowfall occurrences
- Snow isolation index
Args:
grid ("hex", "healpix"): Grid type.
level (int): Grid resolution level.
Returns:
xr.Dataset: Enriched dataset with original and derived variables.
"""
daily_grid_path = _get_grid_paths("daily", grid, level)
daily = xr.open_zarr(daily_grid_path, consolidated=False).set_coords("spatial_ref")
assert {"cell", "time"} == set(daily.dims), f"Expected dims ('cell', 'time'), got {daily.dims}"
# Formulas based on Groeke et. al. (2025) Stochastic Weather generation...
daily["t2m_avg"] = (daily.t2m_max + daily.t2m_min) / 2
daily.t2m_avg.attrs = {"long_name": "Daily average 2 metre temperature", "units": "K"}
daily["t2m_range"] = daily.t2m_max - daily.t2m_min
daily.t2m_range.attrs = {"long_name": "Daily range of 2 metre temperature", "units": "K"}
daily["t2m_skew"] = (daily.t2m_avg - daily.t2m_min) / daily.t2m_range
daily.t2m_skew.attrs = {"long_name": "Daily skewness of 2 metre temperature"}
daily["thawing_degree_days"] = (daily.t2m_avg - 273.15).clip(min=0)
daily.thawing_degree_days.attrs = {"long_name": "Thawing degree days", "units": "K"}
daily["freezing_degree_days"] = (273.15 - daily.t2m_avg).clip(min=0)
daily.freezing_degree_days.attrs = {"long_name": "Freezing degree days", "units": "K"}
daily["thawing_days"] = (daily.t2m_avg > 273.15).astype(int)
daily.thawing_days.attrs = {"long_name": "Thawing days"}
daily["freezing_days"] = (daily.t2m_avg < 273.15).astype(int)
daily.freezing_days.attrs = {"long_name": "Freezing days"}
daily["precipitation_occurrences"] = (daily.tp > 0).astype(int)
daily.precipitation_occurrences.attrs = {"long_name": "Precipitation occurrences"}
daily["snowfall_occurrences"] = (daily.sf > 0).astype(int)
daily.snowfall_occurrences.attrs = {"long_name": "Snowfall occurrences"}
daily["snow_isolation"] = daily.snowc_mean * daily.sde_mean
daily.snow_isolation.attrs = {"long_name": "Snow isolation"}
return daily
def monthly_aggregate(grid: Literal["hex", "healpix"], level: int):
"""Aggregate enriched daily ERA5 data to monthly resolution.
Takes the enriched daily ERA5 data and creates monthly aggregates using
appropriate statistical functions for each variable type. Temperature
variables use min/max/mean, accumulation variables use sums, and derived
variables use appropriate aggregations.
The aggregated monthly data is saved to a zarr file for further processing.
Args:
grid ("hex", "healpix"): Grid type.
level (int): Grid resolution level.
"""
daily = daily_enrich(grid, level)
assert {"cell", "time"} == set(daily.dims), f"Expected dims ('cell', 'time'), got {daily.dims}"
# Monthly aggregates
monthly = xr.merge(
[
# Original variables
daily.t2m_min.resample(time="1ME").min().rename("t2m_min"),
daily.t2m_max.resample(time="1ME").max().rename("t2m_max"),
daily.snowc_mean.resample(time="1ME").mean().rename("snowc_mean"),
daily.sde_mean.resample(time="1ME").mean().rename("sde_mean"),
daily.lblt_max.resample(time="1ME").max().rename("lblt_max"),
daily.tp.resample(time="1ME").sum().rename("tp"),
daily.sf.resample(time="1ME").sum().rename("sf"),
daily.sshf.resample(time="1ME").sum().rename("sshf"),
# Enriched variables
daily.t2m_avg.resample(time="1ME").mean().rename("t2m_avg"),
daily.t2m_range.resample(time="1ME").mean().rename("t2m_mean_range"),
daily.t2m_skew.resample(time="1ME").mean().rename("t2m_mean_skew"),
daily.thawing_degree_days.resample(time="1ME").sum().rename("thawing_degree_days"),
daily.freezing_degree_days.resample(time="1ME").sum().rename("freezing_degree_days"),
daily.thawing_days.resample(time="1ME").sum().rename("thawing_days"),
daily.freezing_days.resample(time="1ME").sum().rename("freezing_days"),
daily.precipitation_occurrences.resample(time="1ME").sum().rename("precipitation_occurrences"),
daily.snowfall_occurrences.resample(time="1ME").sum().rename("snowfall_occurrences"),
daily.snow_isolation.resample(time="1ME").mean().rename("snow_mean_isolation"),
]
)
monthly_grid_path = _get_grid_paths("monthly", grid, level)
monthly.to_zarr(monthly_grid_path, mode="w", encoding=create_encoding(monthly), consolidated=False)
def yearly_aggregate(monthly: xr.Dataset) -> xr.Dataset:
"""Aggregate monthly ERA5 data to yearly resolution.
Takes monthly aggregated data and creates yearly aggregates using a shifted
calendar (October to September) to better capture Arctic seasonal patterns.
Args:
monthly (xr.Dataset): The monthly aggregates
Returns:
xr.Dataset: The aggregated dataset
"""
return xr.merge(
[
# Original variables
monthly.t2m_min.resample(time="1YE").min().rename("t2m_min"),
monthly.t2m_max.resample(time="1YE").max().rename("t2m_max"),
monthly.snowc_mean.resample(time="1YE").mean().rename("snowc_mean"),
monthly.sde_mean.resample(time="1YE").mean().rename("sde_mean"),
monthly.lblt_max.resample(time="1YE").max().rename("lblt_max"),
monthly.tp.resample(time="1YE").sum().rename("tp"),
monthly.sf.resample(time="1YE").sum().rename("sf"),
monthly.sshf.resample(time="1YE").sum().rename("sshf"),
# Enriched variables
monthly.t2m_avg.resample(time="1YE").mean().rename("t2m_avg"),
# TODO: Check if this is correct -> use daily / hourly data instead for range and skew?
monthly.t2m_mean_range.resample(time="1YE").mean().rename("t2m_mean_range"),
monthly.t2m_mean_skew.resample(time="1YE").mean().rename("t2m_mean_skew"),
monthly.thawing_degree_days.resample(time="1YE").sum().rename("thawing_degree_days"),
monthly.freezing_degree_days.resample(time="1YE").sum().rename("freezing_degree_days"),
monthly.thawing_days.resample(time="1YE").sum().rename("thawing_days"),
monthly.freezing_days.resample(time="1YE").sum().rename("freezing_days"),
monthly.precipitation_occurrences.resample(time="1YE").sum().rename("precipitation_occurrences"),
monthly.snowfall_occurrences.resample(time="1YE").sum().rename("snowfall_occurrences"),
monthly.snow_mean_isolation.resample(time="1YE").mean().rename("snow_mean_isolation"),
]
)
def yearly_and_seasonal_aggregate(grid: Literal["hex", "healpix"], level: int):
"""Aggregate monthly ERA5 data to yearly resolution with seasonal splits.
Takes monthly aggregated data and creates yearly aggregates using a shifted
calendar (October to September) to better capture Arctic seasonal patterns.
Creates separate aggregates for full year, winter (Oct-Apr), and summer
(May-Sep) periods.
The first and last incomplete years are excluded from the analysis.
Winter months are defined as months 1-7 in the shifted calendar,
and summer months are 8-12.
The final dataset includes yearly, winter, and summer aggregates for all
climate variables, saved to a zarr file.
Args:
grid ("hex", "healpix"): Grid type.
level (int): Grid resolution level.
"""
monthly_grid_path = _get_grid_paths("monthly", grid, level)
monthly = xr.open_zarr(monthly_grid_path, consolidated=False).set_coords("spatial_ref")
assert {"cell", "time"} == set(monthly.dims), f"Expected dims ('cell', 'time'), got {monthly.dims}"
valid_years = slice(str(monthly.time.min().dt.year.item() + 1), str(monthly.time.max().dt.year.item()))
# Summer aggregates
summer = yearly_aggregate(monthly.sel(time=monthly.time.dt.month.isin([5, 6, 7, 8, 9])).sel(time=valid_years))
# Yearly aggregates (shifted by +8 months to start in Oktober, first and last years will be cropped)
monthly_shifted = monthly.copy()
monthly_shifted["time"] = monthly_shifted.get_index("time") + pd.DateOffset(months=8)
monthly_shifted = monthly_shifted.sel(time=valid_years)
yearly = yearly_aggregate(monthly_shifted)
# Winter aggregates (shifted by +8 months to start in Oktober, first and last years will be cropped)
monthly_shifted = monthly.copy().sel(time=monthly.time.dt.month.isin([1, 2, 3, 4, 10, 11, 12]))
monthly_shifted["time"] = monthly_shifted.get_index("time") + pd.DateOffset(months=8)
monthly_shifted = monthly_shifted.sel(time=valid_years)
winter = yearly_aggregate(monthly_shifted)
yearly_grid_path = _get_grid_paths("yearly", grid, level)
yearly.to_zarr(yearly_grid_path, mode="w", encoding=create_encoding(yearly), consolidated=False)
winter_grid_path = _get_grid_paths("winter", grid, level)
winter.to_zarr(winter_grid_path, mode="w", encoding=create_encoding(winter), consolidated=False)
summer_grid_path = _get_grid_paths("summer", grid, level)
summer.to_zarr(summer_grid_path, mode="w", encoding=create_encoding(summer), consolidated=False)
@cli.command
def temporal_agg(n_workers: int = 10):
"""Perform temporal aggregation of ERA5 data using Dask cluster.
Creates a Dask cluster and runs both monthly and yearly aggregation
functions to generate temporally aggregated climate datasets. The
processing uses parallel workers for efficient computation.
Args:
n_workers (int, optional): Number of Dask workers to use. Defaults to 10.
"""
with (
dd.LocalCluster(n_workers=n_workers, threads_per_worker=20, memory_limit="10GB") as cluster,
dd.Client(cluster) as client,
):
print(client)
print(client.dashboard_link)
monthly_aggregate()
yearly_and_seasonal_aggregate()
print("Enriched ERA5 data with additional features and aggregated it temporally.")
if __name__ == "__main__":
cli()