From f01687f187ef3794b599a788645db3cd55e69c17 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3=B6lzer?=
Date: Fri, 26 Sep 2025 11:05:29 +0200
Subject: [PATCH] Create grid and era5 download

---
 create_grid.py | 241 ++++++++++++++++++++++++++++++
 era5.py        | 397 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 638 insertions(+)
 create mode 100644 create_grid.py
 create mode 100644 era5.py

diff --git a/create_grid.py b/create_grid.py
new file mode 100644
index 0000000..0cf6b1b
--- /dev/null
+++ b/create_grid.py
@@ -0,0 +1,241 @@
+"""Create a global hexagonal (H3) or HEALPix grid and filter it to the permafrost region.
+
+Author: Tobias Hölzer
+Date: 09. June 2025
+"""
+
+from typing import Literal
+
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+import cyclopts
+import geopandas as gpd
+import h3
+import matplotlib.path as mpath
+import matplotlib.pyplot as plt
+import numpy as np
+import xarray as xr
+import xdggs
+import xvec  # noqa: F401
+from rich import pretty, print, traceback
+from shapely.geometry import Polygon
+from shapely.ops import transform
+from stopuhr import stopwatch
+from xdggs.healpix import HealpixInfo
+
+traceback.install()
+pretty.install()
+
+
+@stopwatch("Create a global hex grid")
+def create_global_hex_grid(resolution: int) -> gpd.GeoDataFrame:
+    """Create a global hexagonal grid using H3.
+
+    Args:
+        resolution (int): H3 resolution level (0-15)
+
+    Returns:
+        GeoDataFrame: Global hexagonal grid with the columns "cell_id", "cell_area" (km^2) and "geometry"
+
+    """
+    # Generate hexagons
+    hex0_cells = h3.get_res0_cells()
+
+    if resolution > 0:
+        hex_cells = []
+        for hex0_cell in hex0_cells:
+            hex_cells.extend(h3.cell_to_children(hex0_cell, resolution))
+
+    else:
+        hex_cells = hex0_cells
+
+    # Initialize lists to store hex information
+    hex_list = []
+    hex_id_list = []
+    hex_area_list = []
+
+    # Convert each hex ID to a polygon
+    for hex_id in hex_cells:
+        boundary_coords = h3.cell_to_boundary(hex_id)
+        hex_polygon = Polygon(boundary_coords)
+        hex_polygon = transform(lambda x, y: (y, x), hex_polygon)  # Convert from (lat, lon) to (lon, lat)
+        hex_area = h3.cell_area(hex_id, unit="km^2")
+        hex_list.append(hex_polygon)
+        hex_id_list.append(hex_id)
+        hex_area_list.append(hex_area)
+
+    # Create GeoDataFrame
+    grid = gpd.GeoDataFrame({"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326")
+
+    return grid
+
+
+@stopwatch("Create a global HEALPix grid")
+def create_global_healpix_grid(level: int) -> gpd.GeoDataFrame:
+    """Create a global HEALPix grid.
+
+    Args:
+        level (int): HEALPix level (0-12)
+
+    Returns:
+        GeoDataFrame: Global HEALPix grid with the columns "cell_id", "geometry" and "cell_area" (km^2)
+
+    """
+    grid_info = HealpixInfo(level=level, indexing_scheme="nested")
+    healpix_ds = xr.Dataset(
+        coords={
+            "cell_ids": (
+                "cells",
+                np.arange(12 * 4**grid_info.level),
+                grid_info.to_dict(),
+            )
+        }
+    ).pipe(xdggs.decode)
+
+    cell_ids = healpix_ds.cell_ids.values
+    geometry = healpix_ds.dggs.cell_boundaries()
+
+    # Create GeoDataFrame
+    grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326")
+
+    grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6  # Convert to km^2
+
+    return grid
+
+
+@stopwatch("Filter grid")
+def filter_permafrost_grid(grid: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
+    """Filter an existing grid to permafrost extent & remove water.
+
+    Args:
+        grid (gpd.GeoDataFrame): Input grid with the columns "cell_id", "cell_area" and "geometry"
+
+    Returns:
+        gpd.GeoDataFrame: Filtered grid in EPSG:3413, extended by "water_area", "land_area" and "land_ratio"
+
+    """
+    # Filter for Permafrost region (> 50° latitude)
+    grid = grid[grid.geometry.bounds.miny > 50]
+    # Filter for Arctic Sea (<85° latitude)
+    grid = grid[grid.geometry.bounds.maxy < 85]
+
+    # Convert to arctic stereographic projection
+    grid = grid.to_crs("EPSG:3413")
+
+    # Filter out non-land areas (e.g., oceans)
+    water_mask = gpd.read_file("./data/simplified-water-polygons-split-3857/simplified_water_polygons.shp")
+    water_mask = water_mask.to_crs("EPSG:3413")
+
+    ov = gpd.overlay(grid, water_mask, how="intersection")
+    ov["area"] = ov.geometry.area / 1e6  # Convert to km^2
+    ov = ov.groupby("cell_id").agg({"area": "sum"})
+
+    grid["water_area"] = grid["cell_id"].map(ov["area"]).fillna(0)  # Cells without any overlap have no water
+    grid["land_area"] = grid["cell_area"] - grid["water_area"]
+    grid["land_ratio"] = grid["land_area"] / grid["cell_area"]
+
+    # Filter for land areas (> 10% land)
+    grid = grid[grid["land_ratio"] > 0.1]
+
+    return grid
+
+
+def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
+    """Vizualize the grid on a polar stereographic map.
+
+    Args:
+        data (gpd.GeoDataFrame): The grid data to visualize.
+        grid (str): The type of grid (e.g., "hex" or "healpix").
+        level (int): The level of the grid.
+
+    Returns:
+        plt.Figure: The matplotlib figure object.
+
+    """
+    fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
+    ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree())
+
+    # Add features
+    ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white")
+    ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey")
+    ax.add_feature(cfeature.COASTLINE)
+    ax.add_feature(cfeature.BORDERS, linestyle=":")
+    ax.add_feature(cfeature.LAKES, alpha=0.5)
+    ax.add_feature(cfeature.RIVERS)
+
+    # Add gridlines
+    gl = ax.gridlines(draw_labels=True)
+    gl.top_labels = False
+    gl.right_labels = False
+
+    # Plot grid cells, coloring by 'cell_area'
+    data = data.to_crs("EPSG:4326")
+
+    # Cells whose bounds span more than 180° of longitude wrap around the antimeridian and are not plotted
+    bounds = data.bounds
+    is_anti_meridian = (bounds.maxx - bounds.minx) > 180
+    data = data[~is_anti_meridian]
+    data.plot(
+        ax=ax,
+        column="cell_area",
+        cmap="viridis",
+        legend=True,
+        transform=ccrs.PlateCarree(),
+        edgecolor="k",
+        linewidth=0.2,
+        aspect="equal",
+        alpha=0.5,
+    )
+
+    ax.set_title(f"{grid.capitalize()} grid ({level=})", fontsize=14)
+
+    # Compute a circle in axes coordinates, which we can use as a boundary
+    # for the map. We can pan/zoom as much as we like - the boundary will be
+    # permanently circular.
+    theta = np.linspace(0, 2 * np.pi, 100)
+    center, radius = [0.5, 0.5], 0.5
+    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
+    circle = mpath.Path(verts * radius + center)
+
+    ax.set_boundary(circle, transform=ax.transAxes)
+
+    return fig
+
+
+def cli(grid: Literal["hex", "healpix"], level: int):
+    """CLI entry point.
+
+    Args:
+        grid (Literal["hex", "healpix"]): The grid type to create.
+        level (int): The resolution level of the grid.
+
+    """
+    print(f"Creating {grid} grid at level {level}...")
+    if grid == "hex":
+        grid_gdf = create_global_hex_grid(level)
+    elif grid == "healpix":
+        grid_gdf = create_global_healpix_grid(level)
+    else:
+        print(f"Unknown grid type: {grid}")
+        return
+
+    grid_gdf = filter_permafrost_grid(grid_gdf)
+    print(f"Number of cells at level {level}: {len(grid_gdf)}")
+
+    if grid_gdf.empty:
+        print("No valid grid cells found.")
+        return
+
+    grid_file = f"./data/grids/permafrost_{grid}{level}_grid.parquet"
+    grid_gdf.to_parquet(grid_file)
+    print(f"Saved to {grid_file}")
+
+    fig = vizualize_grid(grid_gdf, grid, level)
+    figure_file = f"./figures/permafrost_{grid}{level}_grid.png"
+    fig.savefig(figure_file, dpi=300)
+    print(f"Saved figure to {figure_file}")
+    plt.close(fig)
+
+
+if __name__ == "__main__":
+    cyclopts.run(cli)
diff --git a/era5.py b/era5.py
new file mode 100644
index 0000000..d4b421e
--- /dev/null
+++ b/era5.py
@@ -0,0 +1,397 @@
+"""Download and preprocess ERA5 data.
+
+Variables of Interest:
+- 2 metre temperature (t2m)
+- Total precipitation (tp)
+- Snow Fall (sf)
+- Snow cover (snowc)
+- Snow depth (sde)
+- Surface sensible heat flux (sshf)
+- Lake ice bottom temperature (lblt)
+
+Aggregations:
+- Summer / Winter 20-bin histogram?
+
+Spatial -> Enrich -> Temporal ?
+
+Author: Tobias Hölzer
+Date: 09. June 2025
+"""
+
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Literal, Optional
+
+import cyclopts
+import dask.distributed as dd
+import geopandas as gpd
+import odc.geo
+import odc.geo.xr
+import pandas as pd
+import shapely
+import shapely.ops
+import xarray as xr
+from numcodecs.zarr3 import Blosc
+from rich import pretty, print, traceback
+from shapely.geometry import LineString, Polygon
+
+traceback.install(show_locals=True)
+pretty.install()
+
+DATA_DIR = Path("data/era5")
+AGG_PATH = DATA_DIR / "era5_agg.zarr"
+ALIGNED_PATH = DATA_DIR / "era5_spatial_aligned.zarr"
+MONTHLY_PATH = DATA_DIR / "era5_monthly.zarr"
+YEARLY_PATH = DATA_DIR / "era5_yearly.zarr"
+
+min_lat = 50
+max_lat = 85
+min_time = "2022-01-01"
+max_time = "2024-12-31"
+subset = {"latitude": slice(max_lat, min_lat), "time": slice(min_time, max_time)}
+
+# NOTE(review): DATA_DIR is rebound below, after the *_PATH constants above were derived from the relative
+# "data/era5" directory. The zarr stores therefore stay relative, while `cli` reads the grid from the absolute
+# directory - confirm that this split is intended.
+DATA_DIR = Path("/isipd/projects/p_aicore_pf/tohoel001/era5_thawing_data")
+
+today = time.strftime("%Y-%m-%d")
+
+
+# TODO: I think it would be better to aggregate via hours instead of days
+# Pipeline would be:
+# Download hourly data -> Spatially match hourly data ->
+# For {daily, monthly, yearly}:
+#     Enrich -> Aggregate temporally
+
+# Temporal aggregation plan shared by the monthly, yearly and seasonal steps.
+# Each entry is (daily variable, reducer, name template of the aggregated variable). "{period}" is replaced by
+# "monthly", "yearly", "winter" or "summer", so every aggregation level derives its variable names from one place.
+TEMPORAL_AGGREGATIONS = [
+    # Original variables
+    ("t2m_daily_min", "min", "t2m_{period}_min"),
+    ("t2m_daily_max", "max", "t2m_{period}_max"),
+    ("tp_daily_sum", "sum", "tp_{period}_sum"),
+    ("sf_daily_sum", "sum", "sf_{period}_sum"),
+    ("snowc_daily_mean", "mean", "snowc_{period}_mean"),
+    ("sde_daily_mean", "mean", "sde_{period}_mean"),
+    ("sshf_daily_sum", "sum", "sshf_{period}_sum"),
+    ("lblt_daily_max", "max", "lblt_{period}_max"),
+    # Enriched variables
+    ("t2m_daily_avg", "mean", "t2m_{period}_avg"),
+    ("t2m_daily_range", "mean", "t2m_daily_range_{period}_avg"),
+    ("t2m_daily_skew", "mean", "t2m_daily_skew_{period}_avg"),
+    ("thawing_degree_days", "sum", "thawing_degree_days_{period}"),
+    ("freezing_degree_days", "sum", "freezing_degree_days_{period}"),
+    ("thawing_days", "sum", "thawing_days_{period}"),
+    ("freezing_days", "sum", "freezing_days_{period}"),
+    ("precipitation_occurrences", "sum", "precipitation_occurrences_{period}"),
+    ("snowfall_occurrences", "sum", "snowfall_occurrences_{period}"),
+    ("snow_isolation", "mean", "snow_isolation_{period}_mean"),
+]
+
+
+def create_encoding(ds: xr.Dataset) -> dict:
+    """Build the zarr encoding for a dataset.
+
+    Args:
+        ds (xr.Dataset): The dataset that is about to be written.
+
+    Returns:
+        dict: Maps every data variable and coordinate of `ds` to a zstd (level 9) Blosc compressor setting.
+
+    """
+    # encoding = {var: {"compressors": BloscCodec(cname="zlib", clevel=9)} for var in ds.data_vars}
+    return {var: {"compressors": Blosc(cname="zstd", clevel=9)} for var in [*ds.data_vars, *ds.coords]}
+
+
+def download_daily_aggregated():
+    """Download ERA5-Land data, aggregate it to daily values and store the result at AGG_PATH.
+
+    The remote dataset is restricted to `subset` (latitude band and time range) before its values are resampled
+    to daily maxima, minima and sums.
+
+    """
+    era5 = xr.open_dataset(
+        "https://data.earthdatahub.destine.eu/era5/reanalysis-era5-land-no-antartica-v0.zarr",
+        storage_options={"client_kwargs": {"trust_env": True}},
+        chunks={"latitude": 64 * 4, "longitude": 64 * 4},
+        # chunks={},
+        engine="zarr",
+    ).rename({"valid_time": "time"})
+    era5 = era5.sel(**subset)
+
+    era5_agg = xr.merge(
+        [
+            era5.t2m.resample(time="1D").max().rename("t2m_daily_max"),
+            era5.t2m.resample(time="1D").min().rename("t2m_daily_min"),
+            era5.tp.resample(time="1D").sum().rename("tp_daily_sum"),
+            # era5.sf.resample(time="1D").sum().rename("sf_daily_sum"),
+            # era5.snowc.resample(time="1D").mean().rename("snowc_daily_mean"),
+            # era5.sde.resample(time="1D").mean().rename("sde_daily_mean"),
+            # era5.sshf.resample(time="1D").sum().rename("sshf_daily_sum"),
+            # era5.lblt.resample(time="1D").max().rename("lblt_daily_max"),
+        ]
+    )
+
+    # Rechunk if the first time chunk is not the same as the middle ones (a single chunk needs no rechunking)
+    time_chunks = era5_agg.chunksizes["time"]
+    if len(time_chunks) > 1 and time_chunks[0] != time_chunks[1]:
+        era5_agg = era5_agg.chunk({"time": 120})
+
+    # Assign attributes
+    era5_agg["t2m_daily_max"].attrs = {"long_name": "Daily maximum 2 metre temperature", "units": "K"}
+    era5_agg["t2m_daily_min"].attrs = {"long_name": "Daily minimum 2 metre temperature", "units": "K"}
+    era5_agg["tp_daily_sum"].attrs = {"long_name": "Daily total precipitation", "units": "m"}
+    # era5_agg["sf_daily_sum"].attrs = {"long_name": "Daily total snow fall", "units": "m"}
+    # era5_agg["snowc_daily_mean"].attrs = {"long_name": "Daily mean snow cover", "units": "%"}
+    # era5_agg["sde_daily_mean"].attrs = {"long_name": "Daily mean snow depth", "units": "m"}
+    # era5_agg["sshf_daily_sum"].attrs = {"long_name": "Daily total surface sensible heat flux", "units": "J/m²"}
+    # era5_agg["lblt_daily_max"].attrs = {"long_name": "Daily maximum lake ice bottom temperature", "units": "K"}
+
+    era5_agg.to_zarr(AGG_PATH, mode="w", encoding=create_encoding(era5_agg), consolidated=False)
+
+
+def crosses_antimeridian(geom: Polygon) -> bool:
+    """Check whether a lon/lat polygon crosses the antimeridian (180° longitude).
+
+    A polygon counts as crossing if it has vertices on both sides of 0° longitude and at least one vertex beyond
+    90° of absolute longitude, which tells antimeridian cells apart from cells crossing the prime meridian.
+
+    """
+    coords = shapely.get_coordinates(geom)
+    crosses_any_meridian = (coords[:, 0] > 0).any() and (coords[:, 0] < 0).any()
+    return crosses_any_meridian and abs(coords[:, 0]).max() > 90
+
+
+def split_antimeridian_cell(geom: Polygon) -> list[Polygon]:
+    """Split a polygon that crosses the antimeridian along the 180° meridian.
+
+    Assumes that `geom` crosses the antimeridian (see `crosses_antimeridian`). Negative longitudes are shifted
+    by +360° before splitting, so the returned parts use a 0° to 360° longitude range.
+
+    """
+    coords = shapely.get_coordinates(geom)
+    coords[coords[:, 0] < 0, 0] += 360
+    geom = Polygon(coords)
+    antimeridian = LineString([[180, -90], [180, 90]])
+    polys = shapely.ops.split(geom, antimeridian)
+    return list(polys.geoms)
+
+
+def check_geobox(geobox) -> bool:
+    """Return True if the geobox spans more than one pixel along both of its axes."""
+    x, y = geobox.shape
+    return x > 1 and y > 1
+
+
+def extract_cell_data(idx: int, geom: Polygon) -> Optional[xr.Dataset]:
+    """Extract the spatial mean of the daily ERA5 aggregates within a single grid cell.
+
+    Args:
+        idx (int): Identifier of the cell, used as the coordinate of the new "cell" dimension.
+        geom (Polygon): Cell geometry in EPSG:4326.
+
+    Returns:
+        Optional[xr.Dataset]: The cell mean over time with a length-1 "cell" dimension, or None if the ERA5
+            geobox enclosing each part of the cell is not larger than one pixel along both axes.
+
+    """
+    era5_agg = xr.open_zarr(AGG_PATH)
+    assert {"latitude", "longitude", "time"} == set(era5_agg.dims), (
+        f"Expected dims ('latitude', 'longitude', 'time'), got {era5_agg.dims}"
+    )
+    # Split the geometry in case it crosses the antimeridian
+    geoms = split_antimeridian_cell(geom) if crosses_antimeridian(geom) else [geom]
+    cell_data = []
+    for part in geoms:
+        part = odc.geo.Geometry(part, crs="epsg:4326")
+        if not check_geobox(era5_agg.odc.geobox.enclosing(part)):
+            continue
+        cell_data.append(era5_agg.odc.crop(part).drop_vars("spatial_ref").mean(["latitude", "longitude"]))
+    if not cell_data:
+        return None
+    if len(cell_data) == 1:
+        cell_mean = cell_data[0]
+    else:
+        # NOTE: the parts are averaged unweighted, regardless of their size
+        cell_mean = xr.concat(cell_data, dim="part").mean("part")
+    return cell_mean.expand_dims({"cell": [idx]}).chunk({"cell": 1})
+
+
+def spatial_matching(grid: gpd.GeoDataFrame, n_workers: int = 10):
+    """Extract the ERA5 data of every grid cell and collect it in the zarr store at ALIGNED_PATH.
+
+    Cells are extracted concurrently in a thread pool, the results are written one after another from the
+    calling thread. The first written cell (re)creates the store, all further cells are appended along "cell".
+    Cells that fail or yield no data are reported and skipped.
+
+    Args:
+        grid (gpd.GeoDataFrame): The grid whose cells should be matched.
+        n_workers (int, optional): Number of extraction threads. Defaults to 10.
+
+    """
+    store_created = False
+    with ThreadPoolExecutor(max_workers=n_workers) as executor:
+        futures = {
+            executor.submit(extract_cell_data, idx, row.geometry): idx
+            for idx, row in grid.to_crs("epsg:4326").iterrows()
+        }
+        for future in as_completed(futures):
+            idx = futures[future]
+            try:
+                data = future.result()
+                if data is None:
+                    print(f"Skipping cell {idx}: no ERA5 data found.")
+                    continue
+                if store_created:
+                    # No encoding here: xarray rejects an encoding for variables that already exist in the store
+                    data.to_zarr(ALIGNED_PATH, append_dim="cell", consolidated=False)
+                else:
+                    data.to_zarr(ALIGNED_PATH, mode="w", consolidated=False, encoding=create_encoding(data))
+                    store_created = True
+            except Exception as e:
+                print(f"Error processing cell {idx}: {e}")
+
+
+def daily_enrich() -> xr.Dataset:
+    """Derive additional daily variables from the spatially matched ERA5 data.
+
+    The snow based variables are only derived if their inputs are part of the matched dataset.
+
+    Returns:
+        xr.Dataset: The dataset stored at ALIGNED_PATH, extended by the derived variables.
+
+    """
+    era5 = xr.open_zarr(ALIGNED_PATH)
+    assert {"cell", "time"} == set(era5.dims), f"Expected dims ('cell', 'time'), got {era5.dims}"
+
+    # Formulas based on Groeke et. al. (2025) Stochastic Weather generation...
+    era5["t2m_daily_avg"] = (era5.t2m_daily_max + era5.t2m_daily_min) / 2
+    era5.t2m_daily_avg.attrs = {"long_name": "Daily average 2 metre temperature", "units": "K"}
+    era5["t2m_daily_range"] = era5.t2m_daily_max - era5.t2m_daily_min
+    era5.t2m_daily_range.attrs = {"long_name": "Daily range of 2 metre temperature", "units": "K"}
+    era5["t2m_daily_skew"] = (era5.t2m_daily_avg - era5.t2m_daily_min) / era5.t2m_daily_range
+    era5.t2m_daily_skew.attrs = {"long_name": "Daily skewness of 2 metre temperature"}
+
+    era5["thawing_degree_days"] = (era5.t2m_daily_avg - 273.15).clip(min=0)
+    era5.thawing_degree_days.attrs = {"long_name": "Thawing degree days", "units": "K"}
+    era5["freezing_degree_days"] = (273.15 - era5.t2m_daily_avg).clip(min=0)
+    era5.freezing_degree_days.attrs = {"long_name": "Freezing degree days", "units": "K"}
+
+    era5["thawing_days"] = (era5.t2m_daily_avg > 273.15).astype(int)
+    era5.thawing_days.attrs = {"long_name": "Thawing days"}
+    era5["freezing_days"] = (era5.t2m_daily_avg < 273.15).astype(int)
+    era5.freezing_days.attrs = {"long_name": "Freezing days"}
+
+    era5["precipitation_occurrences"] = (era5.tp_daily_sum > 0).astype(int)
+    era5.precipitation_occurrences.attrs = {"long_name": "Precipitation occurrences"}
+
+    # The snow variables are optional: they only exist if the download step aggregated them
+    if "sf_daily_sum" in era5:
+        era5["snowfall_occurrences"] = (era5.sf_daily_sum > 0).astype(int)
+        era5.snowfall_occurrences.attrs = {"long_name": "Snowfall occurrences"}
+
+    if "snowc_daily_mean" in era5 and "sde_daily_mean" in era5:
+        era5["snow_isolation"] = era5.snowc_daily_mean * era5.sde_daily_mean
+        era5.snow_isolation.attrs = {"long_name": "Snow isolation"}
+
+    return era5
+
+
+def _aggregate(ds: xr.Dataset, freq: str, period: str, source_period: str = "daily") -> xr.Dataset:
+    """Aggregate all available variables of TEMPORAL_AGGREGATIONS to a coarser temporal resolution.
+
+    Args:
+        ds (xr.Dataset): Dataset with a "time" dimension.
+        freq (str): Resampling frequency passed to `resample`, e.g. "1M".
+        period (str): Name of the target period, inserted into the names of the aggregated variables.
+        source_period (str, optional): Period of the variables in `ds`. "daily" reads the daily variable names,
+            anything else reads the names produced by an earlier aggregation to that period. Defaults to "daily".
+
+    Returns:
+        xr.Dataset: The aggregated variables. Variables that are missing in `ds` are skipped.
+
+    """
+    aggregated = []
+    for daily_name, reducer, template in TEMPORAL_AGGREGATIONS:
+        source = daily_name if source_period == "daily" else template.format(period=source_period)
+        if source not in ds:
+            continue
+        resampled = ds[source].resample(time=freq)
+        aggregated.append(getattr(resampled, reducer)().rename(template.format(period=period)))
+    return xr.merge(aggregated)
+
+
+def monthly_aggregate():
+    """Enrich the daily data, aggregate it to monthly values and store the result at MONTHLY_PATH."""
+    era5 = daily_enrich()
+    assert {"cell", "time"} == set(era5.dims), f"Expected dims ('cell', 'time'), got {era5.dims}"
+
+    monthly = _aggregate(era5, freq="1M", period="monthly")
+    monthly.to_zarr(MONTHLY_PATH, mode="w", encoding=create_encoding(monthly), consolidated=False)
+
+
+def yearly_aggregate():
+    """Aggregate the monthly data to yearly, winter and summer values and store the result at YEARLY_PATH.
+
+    Years are counted from October to September, winter covers October to April and summer May to September.
+    The "time" labels refer to this shifted calendar, e.g. the year 2023 covers October 2022 to September 2023.
+
+    """
+    monthly = xr.open_zarr(MONTHLY_PATH)
+    assert {"cell", "time"} == set(monthly.dims), f"Expected dims ('cell', 'time'), got {monthly.dims}"
+
+    # Yearly aggregates: the time axis is shifted by +3 months, so that October becomes the first month of the
+    # shifted year (October -> month 1, ..., September -> month 12).
+    # The first and the last shifted year are dropped.
+    monthly_shifted = monthly.copy()
+    monthly_shifted["time"] = monthly_shifted.get_index("time") + pd.DateOffset(months=3)
+    incomplete_years = {monthly_shifted.time.dt.year.min().item(), monthly_shifted.time.dt.year.max().item()}
+    monthly_shifted = monthly_shifted.sel(time=~monthly_shifted.time.dt.year.isin(incomplete_years))
+    yearly = _aggregate(monthly_shifted, freq="1Y", period="yearly", source_period="monthly")
+
+    # Summer / Winter aggregates
+    # TODO: Check if this is correct -> use daily / hourly data instead for range and skew?
+    winter_months = [1, 2, 3, 4, 5, 6, 7]  # These do NOT correspond to calendar months, but to the shifted months
+    summer_months = [8, 9, 10, 11, 12]
+    monthly_shifted_winter = monthly_shifted.sel(time=monthly_shifted.time.dt.month.isin(winter_months))
+    monthly_shifted_summer = monthly_shifted.sel(time=monthly_shifted.time.dt.month.isin(summer_months))
+    winter = _aggregate(monthly_shifted_winter, freq="1Y", period="winter", source_period="monthly")
+    summer = _aggregate(monthly_shifted_summer, freq="1Y", period="summer", source_period="monthly")
+
+    combined = xr.merge([yearly, summer, winter])
+    combined.to_zarr(YEARLY_PATH, mode="w", encoding=create_encoding(combined), consolidated=False)
+
+
+def cli(grid: Literal["hex", "healpix"], level: int, download: bool = False, n_workers: int = 10):
+    """Run the CLI for ERA5 data processing.
+
+    Args:
+        grid (Literal["hex", "healpix"]): The grid type to use.
+        level (int): The resolution level of the grid.
+        download (bool, optional): Whether to download data. Defaults to False.
+        n_workers (int, optional): Number of workers for parallel processing. Defaults to 10.
+
+    """
+    cluster = dd.LocalCluster(n_workers=n_workers, threads_per_worker=4, memory_limit="20GB")
+    # The context managers close the client and the cluster again, even if a processing step fails
+    with cluster, dd.Client(cluster) as client:
+        print(client)
+        print(client.dashboard_link)
+
+        if download:
+            download_daily_aggregated()
+            print("Downloaded and aggregated ERA5 data.")
+
+        grid_gdf = gpd.read_parquet(DATA_DIR / f"grids/permafrost_{grid}{level}_grid.parquet")
+        spatial_matching(grid_gdf, n_workers=n_workers)
+        print("Spatially matched ERA5 data to grid.")
+        monthly_aggregate()
+        yearly_aggregate()
+        print("Enriched ERA5 data with additional features and aggregated it temporally.")
+
+
+if __name__ == "__main__":
+    cyclopts.run(cli)