Refactor dataset ensemble to allow different temporal modes

This commit is contained in:
Tobias Hölzer 2026-01-11 15:57:14 +01:00
parent 1495f71ac9
commit 231caa62e7
6 changed files with 708 additions and 505 deletions

View file

@ -1,15 +1,27 @@
#! /bin/bash
# pixi shell
# Batch extraction of RTS label datasets onto every supported grid/level
# combination (hex levels 3-6, healpix levels 6-10).
darts extract-darts-rts --grid hex --level 3
darts extract-darts-rts --grid hex --level 4
darts extract-darts-rts --grid hex --level 5
darts extract-darts-rts --grid hex --level 6
darts extract-darts-rts --grid healpix --level 6
darts extract-darts-rts --grid healpix --level 7
darts extract-darts-rts --grid healpix --level 8
darts extract-darts-rts --grid healpix --level 9
darts extract-darts-rts --grid healpix --level 10
# DARTS-v1 yearly (Level-2) extraction
darts extract-darts-v1 --grid hex --level 3
darts extract-darts-v1 --grid hex --level 4
darts extract-darts-v1 --grid hex --level 5
darts extract-darts-v1 --grid hex --level 6
darts extract-darts-v1 --grid healpix --level 6
darts extract-darts-v1 --grid healpix --level 7
darts extract-darts-v1 --grid healpix --level 8
darts extract-darts-v1 --grid healpix --level 9
darts extract-darts-v1 --grid healpix --level 10
# DARTS-v1 temporally aggregated (Level-3) extraction
darts extract-darts-v1-aggregated --grid hex --level 3
darts extract-darts-v1-aggregated --grid hex --level 4
darts extract-darts-v1-aggregated --grid hex --level 5
darts extract-darts-v1-aggregated --grid hex --level 6
darts extract-darts-v1-aggregated --grid healpix --level 6
darts extract-darts-v1-aggregated --grid healpix --level 7
darts extract-darts-v1-aggregated --grid healpix --level 8
darts extract-darts-v1-aggregated --grid healpix --level 9
darts extract-darts-v1-aggregated --grid healpix --level 10
# DARTS ML training labels extraction
darts extract-darts-mllabels --grid hex --level 3
darts extract-darts-mllabels --grid hex --level 4

View file

@ -9,28 +9,140 @@ Date: October 2025
import cyclopts
import geopandas as gpd
import pandas as pd
from rich import pretty, print, traceback
from rich.progress import track
import xarray as xr
import xdggs
from rich import pretty, traceback
from stopuhr import stopwatch
from entropice.spatial import grids
from entropice.utils.paths import (
darts_ml_training_labels_repo,
dartsl2_cov_file,
dartsl2_file,
get_darts_rts_file,
DARTS_MLLABELS_DIR,
DARTS_V1_DIR,
get_darts_file,
)
from entropice.utils.types import Grid
traceback.install()
pretty.install()
darts_v1_l2_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
darts_v1_l2_cov_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
darts_ml_training_labels_repo = DARTS_MLLABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
cli = cyclopts.App(name="darts-rts")
def _load_grid(grid: Grid, level: int) -> tuple[gpd.GeoDataFrame, xr.DataArray]:
    """Open a grid and return its (cell_id, geometry) frame plus per-cell areas.

    Returns:
        tuple[gpd.GeoDataFrame, xr.DataArray]: The grid reduced to the
            "cell_id" and "geometry" columns, and the cell areas indexed
            by cell_id.
    """
    cells = grids.open(grid, level)
    # Hex grids store cell ids as hexadecimal strings; normalize them to ints.
    if grid == "hex":
        cells["cell_id"] = cells["cell_id"].apply(lambda cid: int(cid, 16))
    ids = cells["cell_id"].to_numpy()
    areas = xr.DataArray(cells["cell_area"].to_numpy(), dims=["cell_id"], coords={"cell_id": ids})
    # Downstream consumers only need the id and the geometry.
    return cells[["cell_id", "geometry"]], areas
@stopwatch("Convert xdgss")
def _convert_xdggs(darts: xr.Dataset, grid: Grid, level: int) -> xr.Dataset:
    """Attach DGGS grid metadata to the dataset and decode it with xdggs.

    Renames "cell_id" to the "cell_ids" name expected by xdggs and casts the
    ids to uint64 before decoding.
    """
    darts = darts.rename({"cell_id": "cell_ids"})
    darts["cell_ids"] = darts["cell_ids"].astype("uint64")
    # xdggs calls the h3 grid "h3", not "hex".
    attrs = {"grid_name": "h3" if grid == "hex" else grid, "level": level}
    if grid == "healpix":
        attrs["indexing_scheme"] = "nested"
    darts.cell_ids.attrs = attrs
    return xdggs.decode(darts)
@stopwatch("Process RTS grid")
def _process_rts_grid(
    rts: gpd.GeoDataFrame,
    cov: gpd.GeoDataFrame,
    cell_areas: xr.DataArray,
) -> xr.Dataset:
    """Aggregate RTS labels and coverage onto grid cells (no temporal dimension).

    Args:
        rts (gpd.GeoDataFrame): RTS polygons with a "cell_id" column.
            Areas are computed as geometry.area / 1e6, which assumes a metric CRS.
        cov (gpd.GeoDataFrame): Coverage polygons with a "cell_id" column.
        cell_areas (xr.DataArray): Total cell area (km^2) indexed by cell_id.

    Returns:
        xr.Dataset: Variables count, area_km2, covered_area_km2, coverage and
            density, indexed by cell_id.
    """
    assert "cell_id" in rts.columns, "RTS data must contain cell_id column."
    assert "cell_id" in cov.columns, "Coverage data must contain cell_id column."
    cov["area_km2"] = cov.geometry.area / 1e6
    covered_area: pd.Series = cov.pivot_table(
        index="cell_id",
        values="area_km2",
        aggfunc="sum",
        fill_value=0,
    )["area_km2"]
    covered_area: xr.DataArray = xr.DataArray.from_series(covered_area)
    # Fraction of each cell covered by the dataset.
    coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id)
    rts["area_km2"] = rts.geometry.area / 1e6
    rts = rts.pivot_table(
        index="cell_id",
        values=["geometry", "area_km2"],
        fill_value=0,
        aggfunc={"geometry": "count", "area_km2": "sum"},  # type: ignore[arg-type]
    # BUGFIX: the counted "geometry" column must be renamed to "count", matching
    # _process_rts_yearly_grid and the "count" variable consumers expect.
    ).rename(columns={"geometry": "count"})
    rts: xr.Dataset = xr.Dataset.from_dataframe(rts)
    # Outer join so cells with coverage but no RTS get zeros.
    darts = xr.merge(
        [rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0
    )
    darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0)
    return darts
@stopwatch("Process RTS yearly grid")
def _process_rts_yearly_grid(
    rts: gpd.GeoDataFrame,
    cov: gpd.GeoDataFrame,
    cell_areas: xr.DataArray,
) -> xr.Dataset:
    """Aggregate RTS labels and coverage onto grid cells, keeping a yearly dimension.

    Args:
        rts (gpd.GeoDataFrame): RTS polygons with "cell_id" and "year" columns.
            Areas are computed as geometry.area / 1e6, which assumes a metric CRS.
        cov (gpd.GeoDataFrame): Coverage polygons with "cell_id" and "year" columns.
        cell_areas (xr.DataArray): Total cell area (km^2) indexed by cell_id.

    Returns:
        xr.Dataset: Variables count, area_km2, covered_area_km2, coverage and
            density, indexed by (cell_id, year).
    """
    assert "cell_id" in rts.columns, "RTS data must contain cell_id column."
    assert "cell_id" in cov.columns, "Coverage data must contain cell_id column."
    # Covered area per (cell, year) in km^2.
    cov["area_km2"] = cov.geometry.area / 1e6
    covered_area: pd.Series = cov.pivot_table(  # type: ignore[var-annotated] # noqa: PD010
        index="cell_id",
        columns="year",
        values="area_km2",
        aggfunc="sum",
        fill_value=0,
    ).unstack()
    covered_area: xr.DataArray = xr.DataArray.from_series(covered_area)
    # Fraction of each cell covered by the dataset in that year.
    coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id)
    rts["area_km2"] = rts.geometry.area / 1e6
    # Per (cell, year): number of RTS polygons (counted "geometry" -> "count")
    # and their total area.
    rts: pd.DataFrame = (
        rts.pivot_table(  # type: ignore[var-annotated] # noqa: PD013
            index="cell_id",
            columns="year",
            values=["geometry", "area_km2"],
            fill_value=0,
            aggfunc={"geometry": "count", "area_km2": "sum"},  # type: ignore[arg-type]
        )
        .stack(future_stack=True)
        .rename(columns={"geometry": "count"})
    )
    rts: xr.Dataset = xr.Dataset.from_dataframe(rts)
    # Outer join so cells with coverage but no RTS get zeros.
    darts = xr.merge(
        [rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0
    )
    darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0)
    return darts
@cli.command()
def extract_darts_rts(grid: Grid, level: int):
"""Extract RTS labels from DARTS dataset.
def extract_darts_v1(grid: Grid, level: int):
"""Extract RTS labels from DARTS-v1 Level-2 dataset.
Creates a Darts-v1 xarray Dataset on the specified grid and level.
The Dataset contains the following variables:
- count: Number of RTS in the cell
- area_km2: Total area of RTS in the cell (in km^2)
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
- coverage: Fraction of the cell covered by DARTS
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
Since the DARTS-v1 Level-2 dataset contains yearly data, all variables are indexed by year as well.
Thus each variable has dimensions (cell_ids, year).
Args:
grid (Grid): The grid type to use.
@ -38,66 +150,81 @@ def extract_darts_rts(grid: Grid, level: int):
"""
with stopwatch("Load data"):
darts_l2 = gpd.read_parquet(dartsl2_file)
darts_cov_l2 = gpd.read_parquet(dartsl2_cov_file)
grid_gdf = grids.open(grid, level)
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
grid_gdf, cell_areas = _load_grid(grid, level)
with stopwatch("Extract RTS labels"):
with stopwatch("Assign RTS to grid"):
grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
grid_cov_l2 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
years = list(grid_cov_l2["year"].unique())
for year in track(years, total=len(years), description="Processing years..."):
with stopwatch("Processing RTS", log=False):
subset = grid_l2[grid_l2["year"] == year]
subset_cov = grid_cov_l2[grid_cov_l2["year"] == year]
darts = _process_rts_yearly_grid(grid_l2, grid_cov_l2, cell_areas)
counts = subset.groupby("cell_id").size()
grid_gdf[f"darts_{year}_rts_count"] = grid_gdf.cell_id.map(counts)
darts = _convert_xdggs(darts, grid, level)
output_path = get_darts_file(grid, level, version="v1")
with stopwatch(f"Writing Darts v1 to {output_path}"):
darts.to_zarr(output_path, consolidated=False, mode="w")
areas = subset.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf[f"darts_{year}_rts_area"] = grid_gdf.cell_id.map(areas)
areas_cov = subset_cov.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf[f"darts_{year}_covered_area"] = grid_gdf.cell_id.map(areas_cov)
grid_gdf[f"darts_{year}_coverage"] = grid_gdf[f"darts_{year}_covered_area"] / grid_gdf.geometry.area
@cli.command()
def extract_darts_v1_aggregated(grid: Grid, level: int):
"""Extract RTS labels from DARTS-v1 Level-3 dataset.
grid_gdf[f"darts_{year}_rts_density"] = (
grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"]
)
Creates a Darts-v1 xarray Dataset on the specified grid and level.
The Dataset contains the following variables:
- count: Number of RTS in the cell
- area_km2: Total area of RTS in the cell (in km^2)
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
- coverage: Fraction of the cell covered by DARTS
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
Since the DARTS-v1 Level-2 dataset contains yearly data, the data is dissolved then exploded to obtain Level-3 data.
Thus each variable has only the dimension (cell_ids).
# Apply corrections to NaNs
covered = ~grid_gdf[f"darts_{year}_coverage"].isna()
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna(
0.0
)
grid_gdf.loc[covered, f"darts_{year}_rts_density"] = grid_gdf.loc[
covered, f"darts_{year}_rts_density"
].fillna(0.0)
grid_gdf[f"darts_{year}_has_coverage"] = covered
grid_gdf[f"darts_{year}_has_rts"] = grid_gdf[f"darts_{year}_rts_count"] > 0
Args:
grid (Grid): The grid type to use.
level (int): The grid level to use.
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1)
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1)
"""
with stopwatch("Load data"):
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
grid_gdf, cell_areas = _load_grid(grid, level)
# Remove overlapping labels by dissolving
darts_l2 = darts_l2[["geometry"]].dissolve().explode()
darts_cov_l2 = darts_cov_l2[["geometry"]].dissolve().explode()
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")]
darts_counts = grid_gdf[darts_counts_columns]
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1)
with stopwatch("Extract RTS labels"):
grid_l3 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
grid_cov_l3 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
darts_density_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_density")]
darts_density = grid_gdf[darts_density_columns]
grid_gdf["darts_rts_density"] = darts_density.dropna(axis=0, how="all").max(axis=1)
output_path = get_darts_rts_file(grid, level)
grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}")
stopwatch.summary()
darts = _process_rts_grid(grid_l3, grid_cov_l3, cell_areas)
darts = _convert_xdggs(darts, grid, level)
output_path = get_darts_file(grid, level, version="v1-l3")
with stopwatch(f"Writing Darts v1 l3 to {output_path}"):
darts.to_zarr(output_path, consolidated=False, mode="w")
@cli.command()
def extract_darts_mllabels(grid: Grid, level: int):
"""Extract RTS labels from the DARTS-mllabels dataset.
Creates a Darts-mllabels xarray Dataset on the specified grid and level.
The Dataset contains the following variables:
- count: Number of RTS in the cell
- area_km2: Total area of RTS in the cell (in km^2)
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
- coverage: Fraction of the cell covered by DARTS
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
Since the DARTS-mllabels dataset does not contain much data, all variables are aggregated over the entire time period.
Thus each variable has only the dimension (cell_ids).
Args:
grid (Grid): The grid type to use.
level (int): The grid level to use.
"""
with stopwatch("Load data"):
grid_gdf = grids.open(grid, level)
grid_gdf, cell_areas = _load_grid(grid, level)
darts_mllabels = (
gpd.GeoDataFrame(
pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/TrainingLabel*.gpkg")])
@ -118,34 +245,16 @@ def extract_darts_mllabels(grid: Grid, level: int):
.to_crs(grid_gdf.crs)
)
darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode()
with stopwatch("Extract RTS labels"):
grid_mllabels = grid_gdf.overlay(darts_mllabels.to_crs(grid_gdf.crs), how="intersection")
grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels.to_crs(grid_gdf.crs), how="intersection")
with stopwatch("Processing RTS"):
counts = grid_mllabels.groupby("cell_id").size()
grid_gdf["dartsml_rts_count"] = grid_gdf.cell_id.map(counts)
areas = grid_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf["dartsml_rts_area"] = grid_gdf.cell_id.map(areas)
areas_cov = grid_cov_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf["dartsml_covered_area"] = grid_gdf.cell_id.map(areas_cov)
grid_gdf["dartsml_coverage"] = grid_gdf["dartsml_covered_area"] / grid_gdf.geometry.area
grid_gdf["dartsml_rts_density"] = grid_gdf["dartsml_rts_area"] / grid_gdf["dartsml_covered_area"]
# Apply corrections to NaNs
covered = ~grid_gdf["dartsml_coverage"].isna()
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
grid_gdf.loc[covered, "dartsml_rts_density"] = grid_gdf.loc[covered, "dartsml_rts_density"].fillna(0.0)
grid_gdf["dartsml_has_coverage"] = covered
grid_gdf["dartsml_has_rts"] = grid_gdf["dartsml_rts_count"] > 0
output_path = get_darts_rts_file(grid, level, labels=True)
grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}")
stopwatch.summary()
darts = _process_rts_grid(grid_mllabels, grid_cov_mllabels, cell_areas)
darts = _convert_xdggs(darts, grid, level)
output_path = get_darts_file(grid, level, version="mllabels")
with stopwatch(f"Writing Darts v1 to {output_path}"):
darts.to_zarr(output_path, consolidated=False, mode="w")
def main(): # noqa: D103

View file

@ -1,4 +1,4 @@
# ruff: noqa: N806
# ruff: noqa: N806, D105
"""Training dataset preparation and model training.
Naming conventions:
@ -16,9 +16,9 @@ import hashlib
import json
from collections.abc import Generator
from dataclasses import asdict, dataclass, field
from functools import cached_property
from functools import cache, cached_property
from itertools import product
from typing import Literal, TypedDict
from typing import Literal, cast
import cupy as cp
import cyclopts
@ -33,8 +33,9 @@ from sklearn import set_config
from sklearn.model_selection import train_test_split
from stopuhr import stopwatch
import entropice.spatial.grids
import entropice.utils.paths
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode
traceback.install()
pretty.install()
@ -44,42 +45,53 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid")
covcol: dict[TargetDataset, str] = {
"darts_rts": "darts_has_coverage",
"darts_mllabels": "dartsml_has_coverage",
}
def _unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset:
# In the yearly case, no unstacking is necessary, we can just rename the time dimension to year and change the coord
if aggregation == "yearly":
era5 = era5.rename({"time": "year"})
era5.coords["year"] = era5["year"].dt.year
return era5
taskcol: dict[Task, dict[TargetDataset, str]] = {
"binary": {
"darts_rts": "darts_has_rts",
"darts_mllabels": "dartsml_has_rts",
},
"count": {
"darts_rts": "darts_rts_count",
"darts_mllabels": "dartsml_rts_count",
},
"density": {
"darts_rts": "darts_rts_density",
"darts_mllabels": "dartsml_rts_density",
},
}
# Make the time index a MultiIndex of year and month
era5.coords["year"] = era5.time.dt.year
era5.coords["month"] = era5.time.dt.month
era5["time"] = pd.MultiIndex.from_arrays(
[
era5.time.dt.year.values, # noqa: PD011
era5.time.dt.month.values, # noqa: PD011
],
names=("year", "month"),
)
era5 = era5.unstack("time") # noqa: PD010
seasons = {10: "winter", 4: "summer"}
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
month_map = seasons if aggregation == "seasonal" else shoulder_seasons
era5.coords["month"] = era5["month"].to_series().map(month_map)
return era5
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
if temporal == "yearly":
return time_index.year
elif temporal == "seasonal":
seasons = {10: "winter", 4: "summer"}
return time_index.month.map(lambda x: seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
elif temporal == "shoulder":
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
    """Flatten a dataset to a 2D frame indexed by cell_ids.

    All index levels other than "cell_ids" are pivoted into columns, and the
    resulting (variable, level-values) column tuples are joined with "_".
    """
    frame = ds.to_dataframe()
    # An empty frame breaks pivot_table, so insert a throwaway all-NaN row
    # (one dummy key per index level) and drop it again afterwards.
    needs_dummy = len(frame) == 0
    if needs_dummy:
        frame.loc[tuple(range(len(frame.index.names)))] = np.nan
    pivot_dims = set(frame.index.names) - {"cell_ids"}
    frame = frame.pivot_table(index="cell_ids", columns=pivot_dims)
    frame.columns = ["_".join(col) for col in frame.columns]
    if needs_dummy:
        frame = frame.dropna(how="all")
    return frame
def _cell_ids_hash(cell_ids: pd.Series) -> str:
sorted_ids = np.sort(cell_ids.to_numpy())
return hashlib.blake2b(sorted_ids.tobytes(), digest_size=8).hexdigest()
def bin_values(
values: pd.Series,
task: Literal["count", "density"],
task: Literal["count_regimes", "density_regimes"],
none_val: float = 0,
) -> pd.Series:
"""Bin values into predefined intervals for different tasks.
@ -89,7 +101,7 @@ def bin_values(
Args:
values (pd.Series): Pandas Series of numerical values to bin.
task (Literal["count", "density"]): Task type - 'count' or 'density'.
task (Literal["count_regimes", "density_regimes"]): Task type - 'count_regimes' or 'density_regimes'.
none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0.
Returns:
@ -100,15 +112,8 @@ def bin_values(
"""
labels_dict = {
"count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
"density": [
"Empty",
"Very Sparse",
"Sparse",
"Moderate",
"Dense",
"Very Dense",
],
"count_regimes": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
"density_regimes": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
}
labels = labels_dict[task]
@ -138,62 +143,67 @@ def bin_values(
@dataclass(frozen=True, eq=False)
class DatasetLabels:
binned: pd.Series
class SplittedArrays:
"""Small wrapper for train and test arrays."""
train: torch.Tensor | np.ndarray | cp.ndarray
test: torch.Tensor | np.ndarray | cp.ndarray
raw_values: pd.Series
@dataclass(frozen=True, eq=False)
class TrainingSet:
"""Container for the training dataset."""
targets: gpd.GeoDataFrame
features: pd.DataFrame
X: SplittedArrays
y: SplittedArrays
z: pd.Series
split: pd.Series
@cached_property
def intervals(self) -> list[tuple[float, float] | tuple[int, int]]:
def target_intervals(self) -> list[tuple[float, float] | tuple[int, int] | tuple[None, None]]:
"""Calculate the intervals for each target category.
Returns:
list[tuple[float, float] | tuple[int, int] | tuple[None, None]]:
List of (min, max) tuples for each category.
"""
binned = self.targets["y"]
raw = self.targets["z"]
assert binned.dtype.name == "category", "Target labels are not categorical."
# For each category get the min and max values from raw_values
intervals = []
for category in self.binned.cat.categories:
category_mask = self.binned == category
intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] = []
for category in binned.cat.categories:
category_mask = binned == category
if category_mask.sum() == 0:
intervals.append((None, None))
else:
category_raw_values = self.raw_values[category_mask]
category_raw_values = raw[category_mask]
intervals.append((category_raw_values.min(), category_raw_values.max()))
return intervals
@cached_property
def labels(self) -> list[str]:
return list(self.binned.cat.categories)
@dataclass(frozen=True, eq=False)
class DatasetInputs:
data: pd.DataFrame
train: torch.Tensor | np.ndarray | cp.ndarray
test: torch.Tensor | np.ndarray | cp.ndarray
@dataclass(frozen=True)
class CategoricalTrainingDataset:
dataset: gpd.GeoDataFrame
X: DatasetInputs
y: DatasetLabels
z: pd.Series
split: pd.Series
def target_labels(self) -> list[str]:
"""Labels of the target categories."""
binned = self.targets["y"]
assert binned.dtype.name == "category", "Target labels are not categorical."
return list(binned.cat.categories)
def __len__(self):
return len(self.z)
class DatasetStats(TypedDict):
target: str
num_target_samples: int
members: dict[str, dict[str, object]]
total_features: int
@cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True)
class DatasetEnsemble:
"""An ensemble of datasets for training and inference."""
grid: Grid
level: int
target: Literal["darts_rts", "darts_mllabels"]
target: TargetDataset
members: list[L2SourceDataset] = field(
default_factory=lambda: [
"AlphaEarth",
@ -203,366 +213,420 @@ class DatasetEnsemble:
"ERA5-shoulder",
]
)
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
variable_filters: dict[str, list[str]] = field(default_factory=dict)
filter_target: str | Literal[False] = False
temporal_mode: TemporalMode = "synopsis"
dimension_filters: dict[L2SourceDataset, dict[str, list]] = field(default_factory=dict)
variable_filters: dict[L2SourceDataset, list[str]] = field(default_factory=dict)
add_lonlat: bool = True
def __post_init__(self):
    """Validate the configured variable and dimension filters.

    Raises:
        AssertionError: If a filter is not a non-empty list, or if a dimension
            filter targets a temporal dimension (temporal filtering must go
            through ``temporal_mode`` instead).
    """
    # Validate filters
    for member in self.members:
        if member in self.variable_filters:
            # This enforces the assumption that the read members always are xarray Datasets and not DataArrays
            assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
                f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
                " Variable filter values must be a list with one or more entries."
            )
        if member in self.dimension_filters:
            for dim, values in self.dimension_filters[member].items():
                # This enforces the assumption that we do the temporal filtering via the temporal_mode
                assert dim not in ["year", "month", "time"], (
                    f"Invalid dimension filter for {member=}: {dim}"
                    " Filtering on 'year', 'month' or 'time' is not supported."
                )
                # This enforces the assumption that there are no empty dimensions in the Dataset
                assert isinstance(values, list) and len(values) >= 1, (
                    f"Invalid dimension filter for {dim=}: {values}"
                    " Dimension filter values must be a list with one or more entries."
                )
def __hash__(self):
    """Hash by the stable content identifier, so equal configs hash equally."""
    return int(self.id(), 16)

def id(self):
    """Return a unique, stable identifier based on the settings of this class."""
    # BUGFIX: this must NOT be decorated with functools.cache. The cache keys
    # on hash(self), and __hash__ itself calls self.id(), so the first call
    # recursed infinitely. Memoize by hand in the instance __dict__ instead
    # (direct __dict__ access bypasses the frozen dataclass __setattr__).
    cached = self.__dict__.get("_id_cache")
    if cached is None:
        cached = hashlib.blake2b(
            json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
            digest_size=16,
        ).hexdigest()
        self.__dict__["_id_cache"] = cached
    return cached
@property
def covcol(self) -> str:
return covcol[self.target]
@cached_property
def cell_ids(self) -> pd.Series:
return self.read_grid()["cell_id"]
def taskcol(self, task: Task) -> str:
return taskcol[task][self.target]
@cached_property
def geometries(self) -> pd.Series:
return self.read_grid()["geometry"]
def _read_member(self, member: L2SourceDataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
if member == "AlphaEarth":
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
elif member == "ArcticDEM":
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
else:
raise NotImplementedError(f"Member {member} not implemented.")
# @stopwatch("Reading grid")
def read_grid(self) -> gpd.GeoDataFrame:
    """Open the ensemble's grid and return it indexed by cell_id.

    Optionally adds the lon/lat of the cell centroids; hex cell ids are
    converted from hexadecimal strings to uint64 integers.

    NOTE(review): set_index("cell_id") drops the column by default, so callers
    that index the result with ["cell_id"] afterwards would KeyError — verify
    callers expect the id on the index, not as a column.
    """
    grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
    # Add the lat / lon of the cell centers
    if self.add_lonlat:
        grid_gdf["lon"] = grid_gdf.geometry.centroid.x
        grid_gdf["lat"] = grid_gdf.geometry.centroid.y
    # Convert hex cell_id to int
    if self.grid == "hex":
        grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64)
    grid_gdf = grid_gdf.set_index("cell_id")
    return grid_gdf
# @stopwatch.f("Getting target labels", print_kwargs=["stage"])
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
    """Create training target labels for a specific task.

    The function reads the target dataset, filters it based on coverage,
    and prepares the target labels according to the specified task.

    This is quite a fast function, with the bottleneck usually being the reading of the zarr store
    (milliseconds to few seconds).
    This is intended to be used by another training Dataset creation function
    and for easy visualization of the targets.

    Args:
        task (Task): The task.

    Returns:
        gpd.GeoDataFrame: GeoDataFrame with target labels, geometry and raw-target values,
            all indexed by the cell_ids.
            Columns: ["geometry", "y", "z"], where "y" is the target label and "z" the raw value.
            For categorical tasks, "y" is categorical and "z" is numerical.
            For regression tasks, both "y" and "z" are numerical and identical.
    """
    # Pick the zarr store matching target dataset and temporal mode:
    # "feature"/"synopsis" use the temporally aggregated Level-3 stores,
    # an integer mode selects a single year from the yearly stores.
    match (self.target, self.temporal_mode):
        case ("darts_v1" | "darts_v2", "feature" | "synopsis"):
            version: Literal["v1-l3", "v2-l3"] = self.target.split("_")[1] + "-l3"  # ty:ignore[invalid-assignment]
            target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
            targets = xr.open_zarr(target_store, consolidated=False)
        case ("darts_v1" | "darts_v2", int()):
            version: Literal["v1", "v2"] = self.target.split("_")[1]  # ty:ignore[invalid-assignment]
            target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
            targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
        case ("darts_mllabels", str()):  # Years are not supported
            target_store = entropice.utils.paths.get_darts_file(
                grid=self.grid, level=self.level, version="mllabels"
            )
            targets = xr.open_zarr(target_store, consolidated=False)
        case _:
            raise NotImplementedError(f"Target {self.target} on {self.temporal_mode} mode not supported.")
    targets = cast(xr.Dataset, targets)
    # Keep only cells that the target dataset actually covers.
    covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series()
    targets = targets.sel(cell_ids=covered_cell_ids.to_numpy())
    # Build label (y) and raw value (z) series for the requested task.
    match task:
        case "binary":
            z = targets["count"].to_series()
            y = (z > 0).astype("category").map({False: "No RTS", True: "RTS"})
        case "count_regimes":
            z = targets["count"].to_series()
            y = bin_values(z, task="count_regimes", none_val=0)
        case "density_regimes":
            z = targets["density"].to_series()
            y = bin_values(z, task="density_regimes", none_val=0)
        case "count":
            z = targets["count"].to_series()
            y = z
        case "density":
            z = targets["density"].to_series()
            y = z
        case _:
            raise NotImplementedError(f"Task {task} not supported.")
    cell_ids = targets["cell_ids"].to_series()
    geometries = self.geometries.loc[cell_ids]
    return gpd.GeoDataFrame(
        {
            "cell_id": cell_ids,
            "geometry": geometries,
            "y": y,
            "z": z,
        }
    ).set_index("cell_id")
# @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"])
def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901
"""Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration.
When `lazy` is False, the data is actually read into memory, which can take some time (seconds).
Args:
member (L2SourceDataset): The member
cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read.
Defaults to None.
lazy (bool, optional): Whether to not load the data and instead return a lazy dataset. Defaults to False.
Returns:
xr.Dataset: Xarray Dataset with the member data for the given cell IDs.
"""
match member:
case "AlphaEarth":
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder":
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
case "ArcticDEM":
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
case _:
raise NotImplementedError(f"Member {member} not implemented.")
ds = xr.open_zarr(store, consolidated=False)
# Apply variable filters
# Apply variable and dimension filters
if member in self.variable_filters:
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
" Variable filter values must be a list with one or more entries."
)
ds = ds[self.variable_filters[member]]
# Apply dimension filters
if member in self.dimension_filters:
for dim, values in self.dimension_filters[member].items():
assert isinstance(values, list) and len(values) >= 1, (
f"Invalid dimension filter for {dim=}: {values}"
" Dimension filter values must be a list with one or more entries."
)
ds = ds.sel({dim: values})
ds = ds.sel(self.dimension_filters[member])
# Delete all coordinates which are not in the dimension
for coord in ds.coords:
if coord not in ds.dims:
ds = ds.drop_vars(coord)
ds = ds.drop_vars([coord for coord in ds.coords if coord not in ds.dims])
# Only load target cell ids
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(targets["cell_id"].values))
if cell_ids is None:
cell_ids = self.cell_ids
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(cell_ids.to_numpy()))
ds = ds.sel(cell_ids=list(intersecting_cell_ids))
# Unstack era5 data if needed
if member.startswith("ERA5"):
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
ds = _unstack_era5_time(ds, era5_agg)
# Apply the temporal mode
match (member.split("-"), self.temporal_mode):
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
ds_mean = ds.mean(dim="year")
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
# Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
ds_trend = ds_trend.rename(
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
)
ds = xr.merge([ds_mean, ds_trend])
case (["ArcticDEM"], "synopsis"):
pass # No temporal dimension
case (_, int() as year):
ds = ds.sel(year=year, drop=True)
case (_, "feature"):
pass
case _:
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
# Actually read data into memory
if not lazy:
ds.load()
return ds
def _read_target(self) -> gpd.GeoDataFrame:
if self.target == "darts_rts":
target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level)
elif self.target == "darts_mllabels":
target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level, labels=True)
else:
raise NotImplementedError(f"Target {self.target} not implemented.")
targets = gpd.read_parquet(target_store)
def make_features(
self,
cell_ids: pd.Series | None = None,
cache_mode: Literal["none", "read", "overwrite"] = "none",
) -> pd.DataFrame:
"""Create a feature DataFrame for the given temporal task and cell IDs.
# Filter to coverage
if self.filter_target:
targets = targets[targets[self.filter_target]]
# Convert hex cell_id to int
if self.grid == "hex":
targets["cell_id"] = targets["cell_id"].apply(lambda x: int(x, 16))
This function reads all members of the ensemble, prepares their features based on the temporal task,
and combines them into a single DataFrame (no geometry).
It is quite computation intensive (seconds to minutes), depending on the configuration and number of cell IDs.
To speed up repeated calls, a caching mechanism is implemented.
# Add the lat / lon of the cell centers
if self.add_lonlat:
targets["lon"] = targets.geometry.centroid.x
targets["lat"] = targets.geometry.centroid.y
The indented use for this function is solely the creation of training and inference datasets.
For visualization purposes it is recommended to use the `read_member` with `lazy=True` functions directly.
return targets
Args:
cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read.
Defaults to None.
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
Returns:
pd.DataFrame: DataFrame with features for the given cell IDs.
"""
if cell_ids is None:
cell_ids = self.cell_ids
# Caching mechanism
if cache_mode != "none":
cells_hash = _cell_ids_hash(cell_ids)
cache_file = entropice.utils.paths.get_features_cache(ensemble_id=self.id(), cells_hash=cells_hash)
if cache_mode == "read" and cache_file.exists():
with stopwatch("Loading features from cache"):
return pd.read_parquet(cache_file)
# Create features
dataset = self.read_grid().loc[cell_ids].drop(columns=["geometry"])
member_dfs = []
for member in self.members:
match member.split("-"):
case ["ERA5", era5_agg]:
member_dfs.append(self._prep_era5(cell_ids, era5_agg))
case ["AlphaEarth"]:
member_dfs.append(self._prep_embeddings(cell_ids))
case ["ArcticDEM"]:
member_dfs.append(self._prep_arcticdem(cell_ids))
case _:
raise NotImplementedError(f"Member {member} not implemented.")
with stopwatch("Joining datasets"):
dataset = dataset.join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
if cache_mode in ["read", "overwrite"]:
with stopwatch("Saving features to cache"):
dataset.to_parquet(cache_file)
return dataset
# @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"])
def _prep_era5(
self,
targets: gpd.GeoDataFrame,
temporal: Literal["yearly", "seasonal", "shoulder"],
cell_ids: pd.Series,
era5_agg: Literal["yearly", "seasonal", "shoulder"],
) -> pd.DataFrame:
era5 = self._read_member("ERA5-" + temporal, targets)
if len(era5["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
variables = list(era5.data_vars)
times = era5.coords["time"].to_numpy()
time_df = pd.DataFrame({"time": times})
time_df.index = pd.DatetimeIndex(times)
tempus = _get_era5_tempus(time_df, temporal)
unique_tempus = tempus.unique()
if "aggregations" in era5.dims:
aggs_list = era5.coords["aggregations"].to_numpy()
expected_cols = [
f"era5_{var}_{t}_{agg}" for var, t, agg in product(variables, unique_tempus, aggs_list)
]
else:
expected_cols = [f"era5_{var}_{t}" for var, t in product(variables, unique_tempus)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
era5_df = era5.to_dataframe()
era5_df["t"] = _get_era5_tempus(era5_df, temporal)
if "aggregations" not in era5.dims:
era5_df = era5_df.pivot_table(index="cell_ids", columns="t")
era5_df.columns = [f"era5_{var}_{t}" for var, t in era5_df.columns]
else:
era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"])
era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False)
era5_df = _collapse_to_dataframe(era5)
era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
era5_df = era5_df.reindex(targets["cell_id"].values, fill_value=np.nan)
era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return era5_df
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
if len(embeddings["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
years = embeddings.coords["year"].to_numpy()
aggs = embeddings.coords["agg"].to_numpy()
bands = embeddings.coords["band"].to_numpy()
expected_cols = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in product(years, aggs, bands)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
# @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"])
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
embeddings_df = _collapse_to_dataframe(embeddings)
embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
embeddings_df = embeddings_df.reindex(targets["cell_id"].values, fill_value=np.nan)
embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return embeddings_df
def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
arcticdem = self._read_member("ArcticDEM", targets)
# @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"])
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
if len(arcticdem["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
variables = list(arcticdem.data_vars)
aggs = arcticdem.coords["aggregations"].to_numpy()
expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
return pd.DataFrame(index=cell_ids.to_numpy(), columns=expected_cols, dtype=float)
arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN
arcticdem_df = arcticdem_df.reindex(targets["cell_id"].values, fill_value=np.nan)
arcticdem_df = arcticdem_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return arcticdem_df
def get_stats(self) -> DatasetStats:
"""Get dataset statistics.
def create_inference_df(
self,
batch_size: int | None = None,
cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> Generator[pd.DataFrame]:
"""Create an inference feature set generator.
Returns:
DatasetStats: Dictionary containing target stats, member stats, and total features count.
This function creates features for all cell IDs in batches.
If `batch_size` is None or greater than the number of cell IDs, all data is returned in a single batch.
Args:
batch_size (int | None, optional): The batch size. If None, all data is returned in a single batch.
Defaults to None.
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
Yields:
Generator[pd.DataFrame]: Generator yielding DataFrames with features for inference.
"""
targets = self._read_target()
stats: DatasetStats = {
"target": self.target,
"num_target_samples": len(targets),
"members": {},
"total_features": 2 if self.add_lonlat else 0, # Lat and Lon
}
for member in self.members:
ds = self._read_member(member, targets, lazy=True)
n_cols_member = len(ds.data_vars)
for dim in ds.sizes:
if dim != "cell_ids":
n_cols_member *= ds.sizes[dim]
stats["members"][member] = {
"variables": list(ds.data_vars),
"num_variables": len(ds.data_vars),
"dimensions": dict(ds.sizes),
"coordinates": list(ds.coords),
"num_features": n_cols_member,
}
stats["total_features"] += n_cols_member
return stats
def print_stats(self):
stats = self.get_stats()
print(f"=== Target: {stats['target']}")
print(f"\tNumber of target samples: {stats['num_target_samples']}")
for member, member_stats in stats["members"].items():
print(f"=== Member: {member}")
print(f"\tVariables ({member_stats['num_variables']}): {member_stats['variables']}")
print(f"\tDimensions: {member_stats['dimensions']}")
print(f"\tCoordinates: {member_stats['coordinates']}")
print(f"\tNumber of features from member: {member_stats['num_features']}")
print(f"=== Total number of features in dataset: {stats['total_features']}")
def create(
self,
filter_target_col: str | None = None,
cache_mode: Literal["n", "o", "r"] = "r",
) -> gpd.GeoDataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
if cache_mode == "r" and cache_file.exists():
with stopwatch("Loading dataset from cache"):
dataset = gpd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
return dataset
with stopwatch("Reading target"):
targets = self._read_target()
if filter_target_col is not None:
targets = targets.loc[targets[filter_target_col]]
print(f"Read and filtered target dataset. ({len(targets)} samples)")
with stopwatch("Preparing member datasets"):
member_dfs = []
for member in self.members:
if member.startswith("ERA5"):
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
member_dfs.append(self._prep_era5(targets, era5_agg))
elif member == "AlphaEarth":
member_dfs.append(self._prep_embeddings(targets))
elif member == "ArcticDEM":
member_dfs.append(self._prep_arcticdem(targets))
else:
raise NotImplementedError(f"Member {member} not implemented.")
print("Prepared all member datasets. Joining...")
with stopwatch("Joining datasets"):
dataset = targets.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
print("Joining complete.")
if cache_mode in ["o", "r"]:
with stopwatch("Saving dataset to cache"):
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
return dataset
def create_batches(
self,
batch_size: int,
filter_target_col: str | None = None,
cache_mode: Literal["n", "o", "r"] = "r",
) -> Generator[pd.DataFrame]:
targets = self._read_target()
if len(targets) == 0:
raise ValueError("No target samples found.")
elif len(targets) < batch_size:
yield self.create(filter_target_col=filter_target_col, cache_mode=cache_mode)
all_cell_ids = self.cell_ids
if batch_size is None or batch_size >= len(all_cell_ids):
yield self.make_features(cell_ids=all_cell_ids, cache_mode=cache_mode)
return
if filter_target_col is not None:
targets = targets.loc[targets[filter_target_col]]
for i in range(0, len(all_cell_ids), batch_size):
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
for i in range(0, len(targets), batch_size):
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.utils.paths.get_dataset_cache(
self.id(), subset=filter_target_col, batch=(i, i + batch_size)
)
if cache_mode == "r" and cache_file.exists():
dataset = gpd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
yield dataset
else:
targets_batch = targets.iloc[i : i + batch_size]
member_dfs = []
for member in self.members:
if member.startswith("ERA5"):
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
member_dfs.append(self._prep_era5(targets_batch, era5_agg))
elif member == "AlphaEarth":
member_dfs.append(self._prep_embeddings(targets_batch))
elif member == "ArcticDEM":
member_dfs.append(self._prep_arcticdem(targets_batch))
else:
raise NotImplementedError(f"Member {member} not implemented.")
dataset = targets_batch.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
if cache_mode in ["o", "r"]:
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
yield dataset
def _cat_and_split(
def create_training_df(
self,
dataset: gpd.GeoDataFrame,
task: Task,
device: Literal["cpu", "cuda", "torch"],
) -> CategoricalTrainingDataset:
taskcol = self.taskcol(task)
cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> pd.DataFrame:
"""Create a simple DataFrame for training with e.g. autogluon.
valid_labels = dataset[taskcol].notna()
This function creates a DataFrame with features and target labels (called "y") for training.
The resulting dataframe contains no geometry information and drops all samples with NaN values.
Does not split into train and test sets, also does not convert to arrays.
For more advanced use cases, use `create_training_set`.
cols_to_drop = {"geometry", taskcol, self.covcol}
cols_to_drop |= {
col
for col in dataset.columns
if col.startswith("dartsml_" if self.target == "darts_mllabels" else "darts_")
}
Args:
task (Task): The task.
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
model_inputs = dataset.drop(columns=cols_to_drop)
# Assert that no column in all-nan
assert not model_inputs.isna().all("index").any(), "Some input columns are all NaN"
# Get valid inputs (rows)
valid_inputs = model_inputs.notna().all("columns")
Returns:
pd.DataFrame: The training DataFrame.
dataset = dataset.loc[valid_labels & valid_inputs]
model_inputs = model_inputs.loc[valid_labels & valid_inputs]
model_labels = dataset[taskcol]
"""
targets = self.get_targets(task)
assert len(targets) > 0, "No target samples found."
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode)
dataset = targets[["y"]].join(features).dropna()
assert len(dataset) > 0, "No valid samples found after joining features and targets."
return dataset
if task == "binary":
binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
elif task == "count":
binned = bin_values(model_labels.astype(int), task=task)
elif task == "density":
binned = bin_values(model_labels, task=task)
else:
raise ValueError("Invalid task.")
def create_training_set(
self,
task: Task,
device: Literal["cpu", "cuda", "torch"] = "cpu",
cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> TrainingSet:
"""Create a full training set for model training.
This function creates a full training set with features and target labels,
splits it into train and test sets, converts to arrays and moves to the specified device.
Args:
task (Task): The task.
device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu".
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
Returns:
TrainingSet: The training set.
Contains a lot of useful information for training and evaluation.
"""
targets = self.get_targets(task)
assert len(targets) > 0, "No target samples found."
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna()
assert len(features) > 0, "No valid features found after dropping NaNs."
assert not features.isna().all("index").any(), "Some feature columns are all NaN"
# Create train / test split
train_idx, test_idx = train_test_split(dataset.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True)
split = pd.Series(index=dataset.index, dtype=object)
train_idx, test_idx = train_test_split(features.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True)
split = pd.Series(index=features.index, dtype=object)
split.loc[train_idx] = "train"
split.loc[test_idx] = "test"
split = split.astype("category")
X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64")
X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64")
y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64")
y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64")
X_train = features.loc[train_idx].to_numpy(dtype="float64")
X_test = features.loc[test_idx].to_numpy(dtype="float64")
if task in ["binary", "count_regimes", "density_regimes"]:
y_train = targets.loc[train_idx, "y"].cat.codes.to_numpy(dtype="int64")
y_test = targets.loc[test_idx, "y"].cat.codes.to_numpy(dtype="int64")
else:
y_train = targets.loc[train_idx, "y"].to_numpy(dtype="float64")
y_test = targets.loc[test_idx, "y"].to_numpy(dtype="float64")
if device == "cuda":
X_train = cp.asarray(X_train)
@ -578,28 +642,13 @@ class DatasetEnsemble:
y_test = torch.from_numpy(y_test).to(device=torch_device)
print(f"Using torch device: {torch.cuda.get_device_name(X_train.device) if X_train.is_cuda else 'cpu'}")
else:
assert device == "cpu", "Invalid device specified."
assert device == "cpu", f"Invalid {device=} specified."
return CategoricalTrainingDataset(
dataset=dataset.to_crs("EPSG:4326"),
X=DatasetInputs(data=model_inputs, train=X_train, test=X_test),
y=DatasetLabels(binned=binned, train=y_train, test=y_test, raw_values=model_labels),
z=model_labels,
return TrainingSet(
targets=targets.loc[features.index].to_crs("EPSG:4326"),
features=features,
X=SplittedArrays(train=X_train, test=X_test),
y=SplittedArrays(train=y_train, test=y_test),
z=targets.loc[features.index, "z"],
split=split,
)
def create_cat_training_dataset(
self, task: Task, device: Literal["cpu", "cuda", "torch"]
) -> CategoricalTrainingDataset:
"""Create a categorical dataset for training.
Args:
task (Task): Task type.
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
Returns:
CategoricalTrainingDataset: The prepared categorical training dataset.
"""
dataset = self.create(filter_target_col=self.covcol)
return self._cat_and_split(dataset, task, device)

View file

@ -31,7 +31,7 @@ traceback.install()
pretty.install()
def open(grid: Grid, level: int):
def open(grid: Grid, level: int) -> gpd.GeoDataFrame:
"""Open a saved grid from parquet file.
Args:
@ -162,7 +162,7 @@ def create_global_hex_grid(resolution):
grid = gpd.GeoDataFrame(
{"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list},
crs="EPSG:4326",
)
) # ty:ignore[no-matching-overload]
return grid
@ -193,7 +193,7 @@ def create_global_healpix_grid(level: int):
geometry = healpix_ds.dggs.cell_boundaries()
# Create GeoDataFrame
grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326")
grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326") # ty:ignore[no-matching-overload]
grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2
@ -250,18 +250,18 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
"""
fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree())
ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree()) # ty:ignore[unresolved-attribute]
# Add features
ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white")
ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey")
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.BORDERS, linestyle=":")
ax.add_feature(cfeature.LAKES, alpha=0.5)
ax.add_feature(cfeature.RIVERS)
ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white") # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey") # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.COASTLINE) # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.BORDERS, linestyle=":") # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.LAKES, alpha=0.5) # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.RIVERS) # ty:ignore[unresolved-attribute]
# Add gridlines
gl = ax.gridlines(draw_labels=True)
gl = ax.gridlines(draw_labels=True) # ty:ignore[unresolved-attribute]
gl.top_labels = False
gl.right_labels = False
@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
verts = np.vstack([np.sin(theta), np.cos(theta)]).T
circle = mpath.Path(verts * radius + center)
ax.set_boundary(circle, transform=ax.transAxes)
ax.set_boundary(circle, transform=ax.transAxes) # ty:ignore[unresolved-attribute]
return fig

View file

@ -15,8 +15,9 @@ DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcodin
GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = Path("figures")
RTS_DIR = DATA_DIR / "darts-rts"
RTS_LABELS_DIR = DATA_DIR / "darts-rts-mllabels"
DARTS_V1_DIR = DATA_DIR / "darts-v1"
DARTS_V2_DIR = DATA_DIR / "darts-v2"
DARTS_MLLABELS_DIR = DATA_DIR / "darts-mllabels"
ERA5_DIR = DATA_DIR / "era5"
ARCTICDEM_DIR = DATA_DIR / "arcticdem"
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
@ -27,7 +28,9 @@ RESULTS_DIR = DATA_DIR / "results"
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
RTS_DIR.mkdir(parents=True, exist_ok=True)
DARTS_V1_DIR.mkdir(parents=True, exist_ok=True)
DARTS_V2_DIR.mkdir(parents=True, exist_ok=True)
DARTS_MLLABELS_DIR.mkdir(parents=True, exist_ok=True)
ERA5_DIR.mkdir(parents=True, exist_ok=True)
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
@ -39,10 +42,6 @@ DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
dartsl2_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
def _get_gridname(grid: Grid, level: int) -> str:
return f"permafrost_{grid}{level}"
@ -60,13 +59,18 @@ def get_grid_viz_file(grid: Grid, level: int) -> Path:
return vizfile
def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path:
def get_darts_file(grid: Grid, level: int, version: Literal["v1", "v1-l3", "v2", "v2-l3", "mllabels"]) -> Path:
gridname = _get_gridname(grid, level)
if labels:
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
else:
rtsfile = RTS_DIR / f"{gridname}_darts.parquet"
return rtsfile
match version:
case "v1" | "v1-l3":
darts_file = DARTS_V1_DIR / f"{gridname}_darts-{version}.zarr"
case "v2" | "v2-l3":
darts_file = DARTS_V2_DIR / f"{gridname}_darts-{version}.zarr"
case "mllabels":
darts_file = DARTS_MLLABELS_DIR / f"{gridname}_darts-{version}.zarr"
case _:
raise ValueError(f"Unknown DARTS version: {version}")
return darts_file
def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
@ -85,13 +89,20 @@ def get_era5_stores(
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
grid: Grid | None = None,
level: int | None = None,
temporal: Literal["synopsis"] | None = None,
) -> Path:
if grid is None or level is None:
(ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True)
return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr"
pdir = ERA5_DIR
if temporal is not None:
agg += f"_{temporal}" # ty:ignore[invalid-assignment]
fname = f"{agg}_climate.zarr"
gridname = _get_gridname(grid, level)
aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr"
if grid is None or level is None:
pdir = pdir / "intermediate"
pdir.mkdir(parents=True, exist_ok=True)
else:
gridname = _get_gridname(grid, level)
fname = f"{gridname}_{fname}"
aligned_path = pdir / fname
return aligned_path
@ -122,6 +133,12 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int
return cache_file
def get_features_cache(ensemble_id: str, cells_hash: str) -> Path:
cache_dir = DATASET_ENSEMBLES_DIR / "cache" / "features" / ensemble_id
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / f"cells_{cells_hash}.parquet"
def get_cv_results_dir(
name: str,
grid: Grid,
@ -133,3 +150,15 @@ def get_cv_results_dir(
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True)
return results_dir
def get_autogluon_results_dir(
grid: Grid,
level: int,
task: Task,
) -> Path:
gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
results_dir = RESULTS_DIR / f"{gridname}_autogluon_{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True)
return results_dir

View file

@ -15,11 +15,14 @@ type GridLevel = Literal[
"healpix9",
"healpix10",
]
type TargetDataset = Literal["darts_rts", "darts_mllabels"]
type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"]
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count", "density"]
type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
# TODO: Consider implementing a "timeseries" temporal mode
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
type Model = Literal["espa", "xgboost", "rf", "knn"]
type Stage = Literal["train", "inference", "visualization"]
@dataclass(frozen=True)
@ -70,8 +73,9 @@ class GridConfig:
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
all_tasks: list[Task] = ["binary", "count", "density"]
all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"]
all_l2_source_datasets: list[L2SourceDataset] = [
"ArcticDEM",
"ERA5-shoulder",