Refactor dataset ensemble to allow different temporal modes

This commit is contained in:
Tobias Hölzer 2026-01-11 15:57:14 +01:00
parent 1495f71ac9
commit 231caa62e7
6 changed files with 708 additions and 505 deletions

View file

@ -1,15 +1,27 @@
#! /bin/bash #! /bin/bash
# pixi shell # pixi shell
darts extract-darts-rts --grid hex --level 3 darts extract-darts-v1 --grid hex --level 3
darts extract-darts-rts --grid hex --level 4 darts extract-darts-v1 --grid hex --level 4
darts extract-darts-rts --grid hex --level 5 darts extract-darts-v1 --grid hex --level 5
darts extract-darts-rts --grid hex --level 6 darts extract-darts-v1 --grid hex --level 6
darts extract-darts-rts --grid healpix --level 6 darts extract-darts-v1 --grid healpix --level 6
darts extract-darts-rts --grid healpix --level 7 darts extract-darts-v1 --grid healpix --level 7
darts extract-darts-rts --grid healpix --level 8 darts extract-darts-v1 --grid healpix --level 8
darts extract-darts-rts --grid healpix --level 9 darts extract-darts-v1 --grid healpix --level 9
darts extract-darts-rts --grid healpix --level 10 darts extract-darts-v1 --grid healpix --level 10
darts extract-darts-v1-aggregated --grid hex --level 3
darts extract-darts-v1-aggregated --grid hex --level 4
darts extract-darts-v1-aggregated --grid hex --level 5
darts extract-darts-v1-aggregated --grid hex --level 6
darts extract-darts-v1-aggregated --grid healpix --level 6
darts extract-darts-v1-aggregated --grid healpix --level 7
darts extract-darts-v1-aggregated --grid healpix --level 8
darts extract-darts-v1-aggregated --grid healpix --level 9
darts extract-darts-v1-aggregated --grid healpix --level 10
darts extract-darts-mllabels --grid hex --level 3 darts extract-darts-mllabels --grid hex --level 3
darts extract-darts-mllabels --grid hex --level 4 darts extract-darts-mllabels --grid hex --level 4

View file

@ -9,28 +9,140 @@ Date: October 2025
import cyclopts import cyclopts
import geopandas as gpd import geopandas as gpd
import pandas as pd import pandas as pd
from rich import pretty, print, traceback import xarray as xr
from rich.progress import track import xdggs
from rich import pretty, traceback
from stopuhr import stopwatch from stopuhr import stopwatch
from entropice.spatial import grids from entropice.spatial import grids
from entropice.utils.paths import ( from entropice.utils.paths import (
darts_ml_training_labels_repo, DARTS_MLLABELS_DIR,
dartsl2_cov_file, DARTS_V1_DIR,
dartsl2_file, get_darts_file,
get_darts_rts_file,
) )
from entropice.utils.types import Grid from entropice.utils.types import Grid
traceback.install() traceback.install()
pretty.install() pretty.install()
darts_v1_l2_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
darts_v1_l2_cov_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
darts_ml_training_labels_repo = DARTS_MLLABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
cli = cyclopts.App(name="darts-rts") cli = cyclopts.App(name="darts-rts")
def _load_grid(grid: Grid, level: int) -> tuple[gpd.GeoDataFrame, xr.DataArray]:
grid_gdf = grids.open(grid, level)
if grid == "hex":
grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16))
cell_areas = xr.DataArray(
grid_gdf["cell_area"].to_numpy(), dims=["cell_id"], coords={"cell_id": grid_gdf["cell_id"].to_numpy()}
)
# We only want to use the geometry and the cell_id
grid_gdf = grid_gdf[["cell_id", "geometry"]]
return grid_gdf, cell_areas
@stopwatch("Convert xdgss")
def _convert_xdggs(darts: xr.Dataset, grid: Grid, level: int) -> xr.Dataset:
darts = darts.rename({"cell_id": "cell_ids"})
darts["cell_ids"] = darts["cell_ids"].astype("uint64")
gridinfo = {
"grid_name": "h3" if grid == "hex" else grid,
"level": level,
}
if grid == "healpix":
gridinfo["indexing_scheme"] = "nested"
darts.cell_ids.attrs = gridinfo
darts = xdggs.decode(darts)
return darts
@stopwatch("Process RTS grid")
def _process_rts_grid(
rts: gpd.GeoDataFrame,
cov: gpd.GeoDataFrame,
cell_areas: xr.DataArray,
) -> xr.Dataset:
assert "cell_id" in rts.columns, "RTS data must contain cell_id column."
assert "cell_id" in cov.columns, "Coverage data must contain cell_id column."
cov["area_km2"] = cov.geometry.area / 1e6
covered_area: pd.Series = cov.pivot_table(
index="cell_id",
values="area_km2",
aggfunc="sum",
fill_value=0,
)["area_km2"]
covered_area: xr.DataArray = xr.DataArray.from_series(covered_area)
coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id)
rts["area_km2"] = rts.geometry.area / 1e6
rts = rts.pivot_table(
index="cell_id",
values=["geometry", "area_km2"],
fill_value=0,
aggfunc={"geometry": "count", "area_km2": "sum"}, # type: ignore[arg-type]
)
rts: xr.Dataset = xr.Dataset.from_dataframe(rts)
darts = xr.merge(
[rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0
)
darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0)
return darts
@stopwatch("Process RTS yearly grid")
def _process_rts_yearly_grid(
rts: gpd.GeoDataFrame,
cov: gpd.GeoDataFrame,
cell_areas: xr.DataArray,
) -> xr.Dataset:
assert "cell_id" in rts.columns, "RTS data must contain cell_id column."
assert "cell_id" in cov.columns, "Coverage data must contain cell_id column."
cov["area_km2"] = cov.geometry.area / 1e6
covered_area: pd.Series = cov.pivot_table( # type: ignore[var-annotated] # noqa: PD010
index="cell_id",
columns="year",
values="area_km2",
aggfunc="sum",
fill_value=0,
).unstack()
covered_area: xr.DataArray = xr.DataArray.from_series(covered_area)
coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id)
rts: pd.DataFrame = (
rts.pivot_table( # type: ignore[var-annotated] # noqa: PD013
index="cell_id",
columns="year",
values=["geometry", "area_km2"],
fill_value=0,
aggfunc={"geometry": "count", "area_km2": "sum"}, # type: ignore[arg-type]
)
.stack(future_stack=True)
.rename(columns={"geometry": "count"})
)
rts: xr.Dataset = xr.Dataset.from_dataframe(rts)
darts = xr.merge(
[rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0
)
darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0)
return darts
@cli.command() @cli.command()
def extract_darts_rts(grid: Grid, level: int): def extract_darts_v1(grid: Grid, level: int):
"""Extract RTS labels from DARTS dataset. """Extract RTS labels from DARTS-v1 Level-2 dataset.
Creates a Darts-v1 xarray Dataset on the specified grid and level.
The Dataset contains the following variables:
- count: Number of RTS in the cell
- area_km2: Total area of RTS in the cell (in km^2)
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
- coverage: Fraction of the cell covered by DARTS
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
Since the DARTS-v1 Level-2 dataset contains yearly data, all variables are indexed by year as well.
Thus each variable has dimensions (cell_ids, year).
Args: Args:
grid (Grid): The grid type to use. grid (Grid): The grid type to use.
@ -38,66 +150,81 @@ def extract_darts_rts(grid: Grid, level: int):
""" """
with stopwatch("Load data"): with stopwatch("Load data"):
darts_l2 = gpd.read_parquet(dartsl2_file) darts_l2 = gpd.read_parquet(darts_v1_l2_file)
darts_cov_l2 = gpd.read_parquet(dartsl2_cov_file) darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
grid_gdf = grids.open(grid, level) grid_gdf, cell_areas = _load_grid(grid, level)
with stopwatch("Extract RTS labels"): with stopwatch("Assign RTS to grid"):
grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection") grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
grid_cov_l2 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection") grid_cov_l2 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
years = list(grid_cov_l2["year"].unique()) darts = _process_rts_yearly_grid(grid_l2, grid_cov_l2, cell_areas)
for year in track(years, total=len(years), description="Processing years..."):
with stopwatch("Processing RTS", log=False):
subset = grid_l2[grid_l2["year"] == year]
subset_cov = grid_cov_l2[grid_cov_l2["year"] == year]
counts = subset.groupby("cell_id").size() darts = _convert_xdggs(darts, grid, level)
grid_gdf[f"darts_{year}_rts_count"] = grid_gdf.cell_id.map(counts) output_path = get_darts_file(grid, level, version="v1")
with stopwatch(f"Writing Darts v1 to {output_path}"):
darts.to_zarr(output_path, consolidated=False, mode="w")
areas = subset.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf[f"darts_{year}_rts_area"] = grid_gdf.cell_id.map(areas)
areas_cov = subset_cov.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False) @cli.command()
grid_gdf[f"darts_{year}_covered_area"] = grid_gdf.cell_id.map(areas_cov) def extract_darts_v1_aggregated(grid: Grid, level: int):
grid_gdf[f"darts_{year}_coverage"] = grid_gdf[f"darts_{year}_covered_area"] / grid_gdf.geometry.area """Extract RTS labels from DARTS-v1 Level-3 dataset.
grid_gdf[f"darts_{year}_rts_density"] = ( Creates a Darts-v1 xarray Dataset on the specified grid and level.
grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"] The Dataset contains the following variables:
) - count: Number of RTS in the cell
- area_km2: Total area of RTS in the cell (in km^2)
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
- coverage: Fraction of the cell covered by DARTS
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
Since the DARTS-v1 Level-2 dataset contains yearly data, the data is dissolved then exploded to obtain Level-3 data.
Thus each variable has only the dimension (cell_ids).
# Apply corrections to NaNs Args:
covered = ~grid_gdf[f"darts_{year}_coverage"].isna() grid (Grid): The grid type to use.
grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna( level (int): The grid level to use.
0.0
)
grid_gdf.loc[covered, f"darts_{year}_rts_density"] = grid_gdf.loc[
covered, f"darts_{year}_rts_density"
].fillna(0.0)
grid_gdf[f"darts_{year}_has_coverage"] = covered
grid_gdf[f"darts_{year}_has_rts"] = grid_gdf[f"darts_{year}_rts_count"] > 0
grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1) """
grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1) with stopwatch("Load data"):
darts_l2 = gpd.read_parquet(darts_v1_l2_file)
darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file)
grid_gdf, cell_areas = _load_grid(grid, level)
# Remove overlapping labels by dissolving
darts_l2 = darts_l2[["geometry"]].dissolve().explode()
darts_cov_l2 = darts_cov_l2[["geometry"]].dissolve().explode()
darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")] with stopwatch("Extract RTS labels"):
darts_counts = grid_gdf[darts_counts_columns] grid_l3 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection")
grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1) grid_cov_l3 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection")
darts_density_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_density")] darts = _process_rts_grid(grid_l3, grid_cov_l3, cell_areas)
darts_density = grid_gdf[darts_density_columns] darts = _convert_xdggs(darts, grid, level)
grid_gdf["darts_rts_density"] = darts_density.dropna(axis=0, how="all").max(axis=1) output_path = get_darts_file(grid, level, version="v1-l3")
with stopwatch(f"Writing Darts v1 l3 to {output_path}"):
output_path = get_darts_rts_file(grid, level) darts.to_zarr(output_path, consolidated=False, mode="w")
grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}")
stopwatch.summary()
@cli.command() @cli.command()
def extract_darts_mllabels(grid: Grid, level: int): def extract_darts_mllabels(grid: Grid, level: int):
"""Extract RTS labels from the DARTS-mllabels dataset.
Creates a Darts-mllabels xarray Dataset on the specified grid and level.
The Dataset contains the following variables:
- count: Number of RTS in the cell
- area_km2: Total area of RTS in the cell (in km^2)
- covered_area_km2: Area of the cell covered by DARTS (in km^2)
- coverage: Fraction of the cell covered by DARTS
- density: Density of RTS area per covered area (area_km2 / covered_area_km2)
Since the DARTS-mllabels dataset contains not much data, all variables are aggregated over the entire time period.
Thus each variable has only the dimension (cell_ids).
Args:
grid (Grid): The grid type to use.
level (int): The grid level to use.
"""
with stopwatch("Load data"): with stopwatch("Load data"):
grid_gdf = grids.open(grid, level) grid_gdf, cell_areas = _load_grid(grid, level)
darts_mllabels = ( darts_mllabels = (
gpd.GeoDataFrame( gpd.GeoDataFrame(
pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/TrainingLabel*.gpkg")]) pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/TrainingLabel*.gpkg")])
@ -118,34 +245,16 @@ def extract_darts_mllabels(grid: Grid, level: int):
.to_crs(grid_gdf.crs) .to_crs(grid_gdf.crs)
) )
darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode() darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode()
with stopwatch("Extract RTS labels"): with stopwatch("Extract RTS labels"):
grid_mllabels = grid_gdf.overlay(darts_mllabels.to_crs(grid_gdf.crs), how="intersection") grid_mllabels = grid_gdf.overlay(darts_mllabels.to_crs(grid_gdf.crs), how="intersection")
grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels.to_crs(grid_gdf.crs), how="intersection") grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels.to_crs(grid_gdf.crs), how="intersection")
with stopwatch("Processing RTS"): darts = _process_rts_grid(grid_mllabels, grid_cov_mllabels, cell_areas)
counts = grid_mllabels.groupby("cell_id").size() darts = _convert_xdggs(darts, grid, level)
grid_gdf["dartsml_rts_count"] = grid_gdf.cell_id.map(counts) output_path = get_darts_file(grid, level, version="mllabels")
with stopwatch(f"Writing Darts v1 to {output_path}"):
areas = grid_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False) darts.to_zarr(output_path, consolidated=False, mode="w")
grid_gdf["dartsml_rts_area"] = grid_gdf.cell_id.map(areas)
areas_cov = grid_cov_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False)
grid_gdf["dartsml_covered_area"] = grid_gdf.cell_id.map(areas_cov)
grid_gdf["dartsml_coverage"] = grid_gdf["dartsml_covered_area"] / grid_gdf.geometry.area
grid_gdf["dartsml_rts_density"] = grid_gdf["dartsml_rts_area"] / grid_gdf["dartsml_covered_area"]
# Apply corrections to NaNs
covered = ~grid_gdf["dartsml_coverage"].isna()
grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0)
grid_gdf.loc[covered, "dartsml_rts_density"] = grid_gdf.loc[covered, "dartsml_rts_density"].fillna(0.0)
grid_gdf["dartsml_has_coverage"] = covered
grid_gdf["dartsml_has_rts"] = grid_gdf["dartsml_rts_count"] > 0
output_path = get_darts_rts_file(grid, level, labels=True)
grid_gdf.to_parquet(output_path)
print(f"Saved RTS labels to {output_path}")
stopwatch.summary()
def main(): # noqa: D103 def main(): # noqa: D103

View file

@ -1,4 +1,4 @@
# ruff: noqa: N806 # ruff: noqa: N806, D105
"""Training dataset preparation and model training. """Training dataset preparation and model training.
Naming conventions: Naming conventions:
@ -16,9 +16,9 @@ import hashlib
import json import json
from collections.abc import Generator from collections.abc import Generator
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from functools import cached_property from functools import cache, cached_property
from itertools import product from itertools import product
from typing import Literal, TypedDict from typing import Literal, cast
import cupy as cp import cupy as cp
import cyclopts import cyclopts
@ -33,8 +33,9 @@ from sklearn import set_config
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from stopuhr import stopwatch from stopuhr import stopwatch
import entropice.spatial.grids
import entropice.utils.paths import entropice.utils.paths
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode
traceback.install() traceback.install()
pretty.install() pretty.install()
@ -44,42 +45,53 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid") sns.set_theme("talk", "whitegrid")
covcol: dict[TargetDataset, str] = { def _unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset:
"darts_rts": "darts_has_coverage", # In the yearly case, no unstacking is necessary, we can just rename the time dimension to year and change the coord
"darts_mllabels": "dartsml_has_coverage", if aggregation == "yearly":
} era5 = era5.rename({"time": "year"})
era5.coords["year"] = era5["year"].dt.year
return era5
taskcol: dict[Task, dict[TargetDataset, str]] = { # Make the time index a MultiIndex of year and month
"binary": { era5.coords["year"] = era5.time.dt.year
"darts_rts": "darts_has_rts", era5.coords["month"] = era5.time.dt.month
"darts_mllabels": "dartsml_has_rts", era5["time"] = pd.MultiIndex.from_arrays(
}, [
"count": { era5.time.dt.year.values, # noqa: PD011
"darts_rts": "darts_rts_count", era5.time.dt.month.values, # noqa: PD011
"darts_mllabels": "dartsml_rts_count", ],
}, names=("year", "month"),
"density": { )
"darts_rts": "darts_rts_density", era5 = era5.unstack("time") # noqa: PD010
"darts_mllabels": "dartsml_rts_density",
},
}
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
if temporal == "yearly":
return time_index.year
elif temporal == "seasonal":
seasons = {10: "winter", 4: "summer"} seasons = {10: "winter", 4: "summer"}
return time_index.month.map(lambda x: seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
elif temporal == "shoulder":
shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"} shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"}
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_") month_map = seasons if aggregation == "seasonal" else shoulder_seasons
era5.coords["month"] = era5["month"].to_series().map(month_map)
return era5
def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame:
collapsed = ds.to_dataframe()
# Make a dummy row to avoid empty dataframe issues
use_dummy = collapsed.shape[0] == 0
if use_dummy:
collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan
pivcols = set(collapsed.index.names) - {"cell_ids"}
collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols)
collapsed.columns = ["_".join(v) for v in collapsed.columns]
if use_dummy:
collapsed = collapsed.dropna(how="all")
return collapsed
def _cell_ids_hash(cell_ids: pd.Series) -> str:
sorted_ids = np.sort(cell_ids.to_numpy())
return hashlib.blake2b(sorted_ids.tobytes(), digest_size=8).hexdigest()
def bin_values( def bin_values(
values: pd.Series, values: pd.Series,
task: Literal["count", "density"], task: Literal["count_regimes", "density_regimes"],
none_val: float = 0, none_val: float = 0,
) -> pd.Series: ) -> pd.Series:
"""Bin values into predefined intervals for different tasks. """Bin values into predefined intervals for different tasks.
@ -89,7 +101,7 @@ def bin_values(
Args: Args:
values (pd.Series): Pandas Series of numerical values to bin. values (pd.Series): Pandas Series of numerical values to bin.
task (Literal["count", "density"]): Task type - 'count' or 'density'. task (Literal["count_regimes", "density_regimes"]): Task type - 'count_regimes' or 'density_regimes'.
none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0. none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0.
Returns: Returns:
@ -100,15 +112,8 @@ def bin_values(
""" """
labels_dict = { labels_dict = {
"count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"], "count_regimes": ["None", "Very Few", "Few", "Several", "Many", "Very Many"],
"density": [ "density_regimes": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"],
"Empty",
"Very Sparse",
"Sparse",
"Moderate",
"Dense",
"Very Dense",
],
} }
labels = labels_dict[task] labels = labels_dict[task]
@ -138,62 +143,67 @@ def bin_values(
@dataclass(frozen=True, eq=False) @dataclass(frozen=True, eq=False)
class DatasetLabels: class SplittedArrays:
binned: pd.Series """Small wrapper for train and test arrays."""
train: torch.Tensor | np.ndarray | cp.ndarray train: torch.Tensor | np.ndarray | cp.ndarray
test: torch.Tensor | np.ndarray | cp.ndarray test: torch.Tensor | np.ndarray | cp.ndarray
raw_values: pd.Series
@dataclass(frozen=True, eq=False)
class TrainingSet:
"""Container for the training dataset."""
targets: gpd.GeoDataFrame
features: pd.DataFrame
X: SplittedArrays
y: SplittedArrays
z: pd.Series
split: pd.Series
@cached_property @cached_property
def intervals(self) -> list[tuple[float, float] | tuple[int, int]]: def target_intervals(self) -> list[tuple[float, float] | tuple[int, int] | tuple[None, None]]:
"""Calculate the intervals for each target category.
Returns:
list[tuple[float, float] | tuple[int, int] | tuple[None, None]]:
List of (min, max) tuples for each category.
"""
binned = self.targets["y"]
raw = self.targets["z"]
assert binned.dtype.name == "category", "Target labels are not categorical."
# For each category get the min and max values from raw_values # For each category get the min and max values from raw_values
intervals = [] intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] = []
for category in self.binned.cat.categories: for category in binned.cat.categories:
category_mask = self.binned == category category_mask = binned == category
if category_mask.sum() == 0: if category_mask.sum() == 0:
intervals.append((None, None)) intervals.append((None, None))
else: else:
category_raw_values = self.raw_values[category_mask] category_raw_values = raw[category_mask]
intervals.append((category_raw_values.min(), category_raw_values.max())) intervals.append((category_raw_values.min(), category_raw_values.max()))
return intervals return intervals
@cached_property @cached_property
def labels(self) -> list[str]: def target_labels(self) -> list[str]:
return list(self.binned.cat.categories) """Labels of the target categories."""
binned = self.targets["y"]
assert binned.dtype.name == "category", "Target labels are not categorical."
@dataclass(frozen=True, eq=False) return list(binned.cat.categories)
class DatasetInputs:
data: pd.DataFrame
train: torch.Tensor | np.ndarray | cp.ndarray
test: torch.Tensor | np.ndarray | cp.ndarray
@dataclass(frozen=True)
class CategoricalTrainingDataset:
dataset: gpd.GeoDataFrame
X: DatasetInputs
y: DatasetLabels
z: pd.Series
split: pd.Series
def __len__(self): def __len__(self):
return len(self.z) return len(self.z)
class DatasetStats(TypedDict):
target: str
num_target_samples: int
members: dict[str, dict[str, object]]
total_features: int
@cyclopts.Parameter("*") @cyclopts.Parameter("*")
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
class DatasetEnsemble: class DatasetEnsemble:
"""An ensemble of datasets for training and inference."""
grid: Grid grid: Grid
level: int level: int
target: Literal["darts_rts", "darts_mllabels"] target: TargetDataset
members: list[L2SourceDataset] = field( members: list[L2SourceDataset] = field(
default_factory=lambda: [ default_factory=lambda: [
"AlphaEarth", "AlphaEarth",
@ -203,366 +213,420 @@ class DatasetEnsemble:
"ERA5-shoulder", "ERA5-shoulder",
] ]
) )
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict) temporal_mode: TemporalMode = "synopsis"
variable_filters: dict[str, list[str]] = field(default_factory=dict) dimension_filters: dict[L2SourceDataset, dict[str, list]] = field(default_factory=dict)
filter_target: str | Literal[False] = False variable_filters: dict[L2SourceDataset, list[str]] = field(default_factory=dict)
add_lonlat: bool = True add_lonlat: bool = True
def __post_init__(self):
# Validate filters
for member in self.members:
if member in self.variable_filters:
# This enforces the assumption that the read members always are xarray Datasets and not DataArrays
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
" Variable filter values must be a list with one or more entries."
)
if member in self.dimension_filters:
for dim, values in self.dimension_filters[member].items():
# This enforces the assumption that we do the temporal filtering via the temporal_mode
assert dim not in ["year", "month", "time"], (
f"Invalid dimension filter for {member=}: {dim}"
" Filtering on 'year', 'month' or 'time' is not supported."
)
# This enforces the assumption that there are no empty dimensions in the Dataset
assert isinstance(values, list) and len(values) >= 1, (
f"Invalid dimension filter for {dim=}: {values}"
" Dimension filter values must be a list with one or more entries."
)
def __hash__(self): def __hash__(self):
return int(self.id(), 16) return int(self.id(), 16)
@cache
def id(self): def id(self):
"""Return an unique, stable identifier based on the settings of this class."""
return hashlib.blake2b( return hashlib.blake2b(
json.dumps(asdict(self), sort_keys=True).encode("utf-8"), json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
digest_size=16, digest_size=16,
).hexdigest() ).hexdigest()
@property @cached_property
def covcol(self) -> str: def cell_ids(self) -> pd.Series:
return covcol[self.target] return self.read_grid()["cell_id"]
def taskcol(self, task: Task) -> str: @cached_property
return taskcol[task][self.target] def geometries(self) -> pd.Series:
return self.read_grid()["geometry"]
def _read_member(self, member: L2SourceDataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: # @stopwatch("Reading grid")
if member == "AlphaEarth": def read_grid(self) -> gpd.GeoDataFrame:
grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
# Add the lat / lon of the cell centers
if self.add_lonlat:
grid_gdf["lon"] = grid_gdf.geometry.centroid.x
grid_gdf["lat"] = grid_gdf.geometry.centroid.y
# Convert hex cell_id to int
if self.grid == "hex":
grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64)
grid_gdf = grid_gdf.set_index("cell_id")
return grid_gdf
# @stopwatch.f("Getting target labels", print_kwargs=["stage"])
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
"""Create a training target labels for a specific task.
The function reads the target dataset, filters it based on coverage,
and prepares the target labels according to the specified task.
This is quite a fast function, with the bottleneck usually being the reading of the zarr store
(milliseconds to few seconds).
This is intended to be used by another training Dataset creation function
and for easy visualization of the targets.
Args:
task (Task): The task.
Returns:
gpd.GeoDataFrame: GeoDataFrame with target labels, geometry and raw-target values,
all indexed by the cell_ids.
Columns: ["geometry", "y", "z"], where "y" is the target label and "z" the raw value.
For categorical tasks, "y" is categorical and "z" is numerical.
For regression tasks, both "y" and "z" are numerical and identical.
"""
match (self.target, self.temporal_mode):
case ("darts_v1" | "darts_v2", "feature" | "synopsis"):
version: Literal["v1-l3", "v2-l3"] = self.target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment]
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
targets = xr.open_zarr(target_store, consolidated=False)
case ("darts_v1" | "darts_v2", int()):
version: Literal["v1", "v2"] = self.target.split("_")[1] # ty:ignore[invalid-assignment]
target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version)
targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode)
case ("darts_mllabels", str()): # Years are not supported
target_store = entropice.utils.paths.get_darts_file(
grid=self.grid, level=self.level, version="mllabels"
)
targets = xr.open_zarr(target_store, consolidated=False)
case _:
raise NotImplementedError(f"Target {self.target} on {self.temporal_mode} mode not supported.")
targets = cast(xr.Dataset, targets)
covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series()
targets = targets.sel(cell_ids=covered_cell_ids.to_numpy())
match task:
case "binary":
z = targets["count"].to_series()
y = (z > 0).astype("category").map({False: "No RTS", True: "RTS"})
case "count_regimes":
z = targets["count"].to_series()
y = bin_values(z, task="count_regimes", none_val=0)
case "density_regimes":
z = targets["density"].to_series()
y = bin_values(z, task="density_regimes", none_val=0)
case "count":
z = targets["count"].to_series()
y = z
case "density":
z = targets["density"].to_series()
y = z
case _:
raise NotImplementedError(f"Task {task} not supported.")
cell_ids = targets["cell_ids"].to_series()
geometries = self.geometries.loc[cell_ids]
return gpd.GeoDataFrame(
{
"cell_id": cell_ids,
"geometry": geometries,
"y": y,
"z": z,
}
).set_index("cell_id")
# @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"])
def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901
"""Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration.
When `lazy` is False, the data is actually read into memory, which can take some time (seconds).
Args:
member (L2SourceDataset): The member
cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read.
Defaults to None.
lazy (bool, optional): Whether to not load the data and instead return a lazy dataset. Defaults to False.
Returns:
xr.Dataset: Xarray Dataset with the member data for the given cell IDs.
"""
match member:
case "AlphaEarth":
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level) store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]: case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder":
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level) store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
elif member == "ArcticDEM": case "ArcticDEM":
store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level) store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
else: case _:
raise NotImplementedError(f"Member {member} not implemented.") raise NotImplementedError(f"Member {member} not implemented.")
ds = xr.open_zarr(store, consolidated=False) ds = xr.open_zarr(store, consolidated=False)
# Apply variable filters # Apply variable and dimension filters
if member in self.variable_filters: if member in self.variable_filters:
assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, (
f"Invalid variable filter for {member=}: {self.variable_filters[member]}"
" Variable filter values must be a list with one or more entries."
)
ds = ds[self.variable_filters[member]] ds = ds[self.variable_filters[member]]
# Apply dimension filters
if member in self.dimension_filters: if member in self.dimension_filters:
for dim, values in self.dimension_filters[member].items(): ds = ds.sel(self.dimension_filters[member])
assert isinstance(values, list) and len(values) >= 1, (
f"Invalid dimension filter for {dim=}: {values}"
" Dimension filter values must be a list with one or more entries."
)
ds = ds.sel({dim: values})
# Delete all coordinates which are not in the dimension # Delete all coordinates which are not in the dimension
for coord in ds.coords: ds = ds.drop_vars([coord for coord in ds.coords if coord not in ds.dims])
if coord not in ds.dims:
ds = ds.drop_vars(coord)
# Only load target cell ids # Only load target cell ids
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(targets["cell_id"].values)) if cell_ids is None:
cell_ids = self.cell_ids
intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(cell_ids.to_numpy()))
ds = ds.sel(cell_ids=list(intersecting_cell_ids)) ds = ds.sel(cell_ids=list(intersecting_cell_ids))
# Unstack era5 data if needed
if member.startswith("ERA5"):
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
ds = _unstack_era5_time(ds, era5_agg)
# Apply the temporal mode
match (member.split("-"), self.temporal_mode):
case (["ERA5", _] | ["AlphaEarth"], "synopsis"):
ds_mean = ds.mean(dim="year")
ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True)
# Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend"
ds_trend = ds_trend.rename(
{var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars}
)
ds = xr.merge([ds_mean, ds_trend])
case (["ArcticDEM"], "synopsis"):
pass # No temporal dimension
case (_, int() as year):
ds = ds.sel(year=year, drop=True)
case (_, "feature"):
pass
case _:
raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.")
# Actually read data into memory # Actually read data into memory
if not lazy: if not lazy:
ds.load() ds.load()
return ds return ds
def _read_target(self) -> gpd.GeoDataFrame: def make_features(
if self.target == "darts_rts": self,
target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level) cell_ids: pd.Series | None = None,
elif self.target == "darts_mllabels": cache_mode: Literal["none", "read", "overwrite"] = "none",
target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level, labels=True) ) -> pd.DataFrame:
else: """Create a feature DataFrame for the given temporal task and cell IDs.
raise NotImplementedError(f"Target {self.target} not implemented.")
targets = gpd.read_parquet(target_store)
# Filter to coverage This function reads all members of the ensemble, prepares their features based on the temporal task,
if self.filter_target: and combines them into a single DataFrame (no geometry).
targets = targets[targets[self.filter_target]] It is quite computation intensive (seconds to minutes), depending on the configuration and number of cell IDs.
# Convert hex cell_id to int To speed up repeated calls, a caching mechanism is implemented.
if self.grid == "hex":
targets["cell_id"] = targets["cell_id"].apply(lambda x: int(x, 16))
# Add the lat / lon of the cell centers The indented use for this function is solely the creation of training and inference datasets.
if self.add_lonlat: For visualization purposes it is recommended to use the `read_member` with `lazy=True` functions directly.
targets["lon"] = targets.geometry.centroid.x
targets["lat"] = targets.geometry.centroid.y
return targets Args:
cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read.
Defaults to None.
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
Returns:
pd.DataFrame: DataFrame with features for the given cell IDs.
"""
if cell_ids is None:
cell_ids = self.cell_ids
# Caching mechanism
if cache_mode != "none":
cells_hash = _cell_ids_hash(cell_ids)
cache_file = entropice.utils.paths.get_features_cache(ensemble_id=self.id(), cells_hash=cells_hash)
if cache_mode == "read" and cache_file.exists():
with stopwatch("Loading features from cache"):
return pd.read_parquet(cache_file)
# Create features
dataset = self.read_grid().loc[cell_ids].drop(columns=["geometry"])
member_dfs = []
for member in self.members:
match member.split("-"):
case ["ERA5", era5_agg]:
member_dfs.append(self._prep_era5(cell_ids, era5_agg))
case ["AlphaEarth"]:
member_dfs.append(self._prep_embeddings(cell_ids))
case ["ArcticDEM"]:
member_dfs.append(self._prep_arcticdem(cell_ids))
case _:
raise NotImplementedError(f"Member {member} not implemented.")
with stopwatch("Joining datasets"):
dataset = dataset.join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
if cache_mode in ["read", "overwrite"]:
with stopwatch("Saving features to cache"):
dataset.to_parquet(cache_file)
return dataset
# @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"])
def _prep_era5( def _prep_era5(
self, self,
targets: gpd.GeoDataFrame, cell_ids: pd.Series,
temporal: Literal["yearly", "seasonal", "shoulder"], era5_agg: Literal["yearly", "seasonal", "shoulder"],
) -> pd.DataFrame: ) -> pd.DataFrame:
era5 = self._read_member("ERA5-" + temporal, targets) era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False)
era5_df = _collapse_to_dataframe(era5)
if len(era5["cell_ids"]) == 0: era5_df.columns = [f"era5_{col}" for col in era5_df.columns]
# No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure
variables = list(era5.data_vars)
times = era5.coords["time"].to_numpy()
time_df = pd.DataFrame({"time": times})
time_df.index = pd.DatetimeIndex(times)
tempus = _get_era5_tempus(time_df, temporal)
unique_tempus = tempus.unique()
if "aggregations" in era5.dims:
aggs_list = era5.coords["aggregations"].to_numpy()
expected_cols = [
f"era5_{var}_{t}_{agg}" for var, t, agg in product(variables, unique_tempus, aggs_list)
]
else:
expected_cols = [f"era5_{var}_{t}" for var, t in product(variables, unique_tempus)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
era5_df = era5.to_dataframe()
era5_df["t"] = _get_era5_tempus(era5_df, temporal)
if "aggregations" not in era5.dims:
era5_df = era5_df.pivot_table(index="cell_ids", columns="t")
era5_df.columns = [f"era5_{var}_{t}" for var, t in era5_df.columns]
else:
era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"])
era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN # Ensure all target cell_ids are present, fill missing with NaN
era5_df = era5_df.reindex(targets["cell_id"].values, fill_value=np.nan) era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return era5_df return era5_df
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: # @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"])
embeddings = self._read_member("AlphaEarth", targets)["embeddings"] def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
if len(embeddings["cell_ids"]) == 0: embeddings_df = _collapse_to_dataframe(embeddings)
# No data for these cells - create empty DataFrame with expected columns embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns]
# Use the Dataset metadata to determine column structure
years = embeddings.coords["year"].to_numpy()
aggs = embeddings.coords["agg"].to_numpy()
bands = embeddings.coords["band"].to_numpy()
expected_cols = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in product(years, aggs, bands)]
return pd.DataFrame(
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN # Ensure all target cell_ids are present, fill missing with NaN
embeddings_df = embeddings_df.reindex(targets["cell_id"].values, fill_value=np.nan) embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return embeddings_df return embeddings_df
def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: # @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"])
arcticdem = self._read_member("ArcticDEM", targets) def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
if len(arcticdem["cell_ids"]) == 0: if len(arcticdem["cell_ids"]) == 0:
# No data for these cells - create empty DataFrame with expected columns # No data for these cells - create empty DataFrame with expected columns
# Use the Dataset metadata to determine column structure # Use the Dataset metadata to determine column structure
variables = list(arcticdem.data_vars) variables = list(arcticdem.data_vars)
aggs = arcticdem.coords["aggregations"].to_numpy() aggs = arcticdem.coords["aggregations"].to_numpy()
expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)] expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)]
return pd.DataFrame( return pd.DataFrame(index=cell_ids.to_numpy(), columns=expected_cols, dtype=float)
index=targets["cell_id"].values,
columns=expected_cols,
dtype=float,
)
arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations") arcticdem_df = arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations")
arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns] arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns]
# Ensure all target cell_ids are present, fill missing with NaN # Ensure all target cell_ids are present, fill missing with NaN
arcticdem_df = arcticdem_df.reindex(targets["cell_id"].values, fill_value=np.nan) arcticdem_df = arcticdem_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return arcticdem_df return arcticdem_df
def get_stats(self) -> DatasetStats: def create_inference_df(
"""Get dataset statistics. self,
batch_size: int | None = None,
cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> Generator[pd.DataFrame]:
"""Create an inference feature set generator.
Returns: This function creates features for all cell IDs in batches.
DatasetStats: Dictionary containing target stats, member stats, and total features count. If `batch_size` is None or greater than the number of cell IDs, all data is returned in a single batch.
Args:
batch_size (int | None, optional): The batch size. If None, all data is returned in a single batch.
Defaults to None.
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
Yields:
Generator[pd.DataFrame]: Generator yielding DataFrames with features for inference.
""" """
targets = self._read_target() all_cell_ids = self.cell_ids
stats: DatasetStats = { if batch_size is None or batch_size >= len(all_cell_ids):
"target": self.target, yield self.make_features(cell_ids=all_cell_ids, cache_mode=cache_mode)
"num_target_samples": len(targets),
"members": {},
"total_features": 2 if self.add_lonlat else 0, # Lat and Lon
}
for member in self.members:
ds = self._read_member(member, targets, lazy=True)
n_cols_member = len(ds.data_vars)
for dim in ds.sizes:
if dim != "cell_ids":
n_cols_member *= ds.sizes[dim]
stats["members"][member] = {
"variables": list(ds.data_vars),
"num_variables": len(ds.data_vars),
"dimensions": dict(ds.sizes),
"coordinates": list(ds.coords),
"num_features": n_cols_member,
}
stats["total_features"] += n_cols_member
return stats
def print_stats(self):
stats = self.get_stats()
print(f"=== Target: {stats['target']}")
print(f"\tNumber of target samples: {stats['num_target_samples']}")
for member, member_stats in stats["members"].items():
print(f"=== Member: {member}")
print(f"\tVariables ({member_stats['num_variables']}): {member_stats['variables']}")
print(f"\tDimensions: {member_stats['dimensions']}")
print(f"\tCoordinates: {member_stats['coordinates']}")
print(f"\tNumber of features from member: {member_stats['num_features']}")
print(f"=== Total number of features in dataset: {stats['total_features']}")
def create(
self,
filter_target_col: str | None = None,
cache_mode: Literal["n", "o", "r"] = "r",
) -> gpd.GeoDataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
if cache_mode == "r" and cache_file.exists():
with stopwatch("Loading dataset from cache"):
dataset = gpd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
return dataset
with stopwatch("Reading target"):
targets = self._read_target()
if filter_target_col is not None:
targets = targets.loc[targets[filter_target_col]]
print(f"Read and filtered target dataset. ({len(targets)} samples)")
with stopwatch("Preparing member datasets"):
member_dfs = []
for member in self.members:
if member.startswith("ERA5"):
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
member_dfs.append(self._prep_era5(targets, era5_agg))
elif member == "AlphaEarth":
member_dfs.append(self._prep_embeddings(targets))
elif member == "ArcticDEM":
member_dfs.append(self._prep_arcticdem(targets))
else:
raise NotImplementedError(f"Member {member} not implemented.")
print("Prepared all member datasets. Joining...")
with stopwatch("Joining datasets"):
dataset = targets.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
print("Joining complete.")
if cache_mode in ["o", "r"]:
with stopwatch("Saving dataset to cache"):
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
return dataset
def create_batches(
self,
batch_size: int,
filter_target_col: str | None = None,
cache_mode: Literal["n", "o", "r"] = "r",
) -> Generator[pd.DataFrame]:
targets = self._read_target()
if len(targets) == 0:
raise ValueError("No target samples found.")
elif len(targets) < batch_size:
yield self.create(filter_target_col=filter_target_col, cache_mode=cache_mode)
return return
if filter_target_col is not None: for i in range(0, len(all_cell_ids), batch_size):
targets = targets.loc[targets[filter_target_col]] batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
for i in range(0, len(targets), batch_size): def create_training_df(
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.utils.paths.get_dataset_cache(
self.id(), subset=filter_target_col, batch=(i, i + batch_size)
)
if cache_mode == "r" and cache_file.exists():
dataset = gpd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
yield dataset
else:
targets_batch = targets.iloc[i : i + batch_size]
member_dfs = []
for member in self.members:
if member.startswith("ERA5"):
era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
member_dfs.append(self._prep_era5(targets_batch, era5_agg))
elif member == "AlphaEarth":
member_dfs.append(self._prep_embeddings(targets_batch))
elif member == "ArcticDEM":
member_dfs.append(self._prep_arcticdem(targets_batch))
else:
raise NotImplementedError(f"Member {member} not implemented.")
dataset = targets_batch.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
if cache_mode in ["o", "r"]:
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
yield dataset
def _cat_and_split(
self, self,
dataset: gpd.GeoDataFrame,
task: Task, task: Task,
device: Literal["cpu", "cuda", "torch"], cache_mode: Literal["none", "overwrite", "read"] = "read",
) -> CategoricalTrainingDataset: ) -> pd.DataFrame:
taskcol = self.taskcol(task) """Create a simple DataFrame for training with e.g. autogluon.
valid_labels = dataset[taskcol].notna() This function creates a DataFrame with features and target labels (called "y") for training.
The resulting dataframe contains no geometry information and drops all samples with NaN values.
Does not split into train and test sets, also does not convert to arrays.
For more advanced use cases, use `create_training_set`.
cols_to_drop = {"geometry", taskcol, self.covcol} Args:
cols_to_drop |= { task (Task): The task.
col cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
for col in dataset.columns "none": No caching.
if col.startswith("dartsml_" if self.target == "darts_mllabels" else "darts_") "read": Read from cache if exists, otherwise create and save to cache.
} "overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
model_inputs = dataset.drop(columns=cols_to_drop) Returns:
# Assert that no column in all-nan pd.DataFrame: The training DataFrame.
assert not model_inputs.isna().all("index").any(), "Some input columns are all NaN"
# Get valid inputs (rows)
valid_inputs = model_inputs.notna().all("columns")
dataset = dataset.loc[valid_labels & valid_inputs] """
model_inputs = model_inputs.loc[valid_labels & valid_inputs] targets = self.get_targets(task)
model_labels = dataset[taskcol] assert len(targets) > 0, "No target samples found."
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode)
dataset = targets[["y"]].join(features).dropna()
assert len(dataset) > 0, "No valid samples found after joining features and targets."
return dataset
if task == "binary": def create_training_set(
binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category") self,
elif task == "count": task: Task,
binned = bin_values(model_labels.astype(int), task=task) device: Literal["cpu", "cuda", "torch"] = "cpu",
elif task == "density": cache_mode: Literal["none", "overwrite", "read"] = "read",
binned = bin_values(model_labels, task=task) ) -> TrainingSet:
else: """Create a full training set for model training.
raise ValueError("Invalid task.")
This function creates a full training set with features and target labels,
splits it into train and test sets, converts to arrays and moves to the specified device.
Args:
task (Task): The task.
device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu".
cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation.
"none": No caching.
"read": Read from cache if exists, otherwise create and save to cache.
"overwrite": Always create and save to cache, overwriting existing cache.
Defaults to "none".
Returns:
TrainingSet: The training set.
Contains a lot of useful information for training and evaluation.
"""
targets = self.get_targets(task)
assert len(targets) > 0, "No target samples found."
features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna()
assert len(features) > 0, "No valid features found after dropping NaNs."
assert not features.isna().all("index").any(), "Some feature columns are all NaN"
# Create train / test split # Create train / test split
train_idx, test_idx = train_test_split(dataset.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True) train_idx, test_idx = train_test_split(features.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True)
split = pd.Series(index=dataset.index, dtype=object) split = pd.Series(index=features.index, dtype=object)
split.loc[train_idx] = "train" split.loc[train_idx] = "train"
split.loc[test_idx] = "test" split.loc[test_idx] = "test"
split = split.astype("category") split = split.astype("category")
X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64") X_train = features.loc[train_idx].to_numpy(dtype="float64")
X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64") X_test = features.loc[test_idx].to_numpy(dtype="float64")
y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64") if task in ["binary", "count_regimes", "density_regimes"]:
y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64") y_train = targets.loc[train_idx, "y"].cat.codes.to_numpy(dtype="int64")
y_test = targets.loc[test_idx, "y"].cat.codes.to_numpy(dtype="int64")
else:
y_train = targets.loc[train_idx, "y"].to_numpy(dtype="float64")
y_test = targets.loc[test_idx, "y"].to_numpy(dtype="float64")
if device == "cuda": if device == "cuda":
X_train = cp.asarray(X_train) X_train = cp.asarray(X_train)
@ -578,28 +642,13 @@ class DatasetEnsemble:
y_test = torch.from_numpy(y_test).to(device=torch_device) y_test = torch.from_numpy(y_test).to(device=torch_device)
print(f"Using torch device: {torch.cuda.get_device_name(X_train.device) if X_train.is_cuda else 'cpu'}") print(f"Using torch device: {torch.cuda.get_device_name(X_train.device) if X_train.is_cuda else 'cpu'}")
else: else:
assert device == "cpu", "Invalid device specified." assert device == "cpu", f"Invalid {device=} specified."
return CategoricalTrainingDataset( return TrainingSet(
dataset=dataset.to_crs("EPSG:4326"), targets=targets.loc[features.index].to_crs("EPSG:4326"),
X=DatasetInputs(data=model_inputs, train=X_train, test=X_test), features=features,
y=DatasetLabels(binned=binned, train=y_train, test=y_test, raw_values=model_labels), X=SplittedArrays(train=X_train, test=X_test),
z=model_labels, y=SplittedArrays(train=y_train, test=y_test),
z=targets.loc[features.index, "z"],
split=split, split=split,
) )
def create_cat_training_dataset(
self, task: Task, device: Literal["cpu", "cuda", "torch"]
) -> CategoricalTrainingDataset:
"""Create a categorical dataset for training.
Args:
task (Task): Task type.
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
Returns:
CategoricalTrainingDataset: The prepared categorical training dataset.
"""
dataset = self.create(filter_target_col=self.covcol)
return self._cat_and_split(dataset, task, device)

View file

@ -31,7 +31,7 @@ traceback.install()
pretty.install() pretty.install()
def open(grid: Grid, level: int): def open(grid: Grid, level: int) -> gpd.GeoDataFrame:
"""Open a saved grid from parquet file. """Open a saved grid from parquet file.
Args: Args:
@ -162,7 +162,7 @@ def create_global_hex_grid(resolution):
grid = gpd.GeoDataFrame( grid = gpd.GeoDataFrame(
{"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, {"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list},
crs="EPSG:4326", crs="EPSG:4326",
) ) # ty:ignore[no-matching-overload]
return grid return grid
@ -193,7 +193,7 @@ def create_global_healpix_grid(level: int):
geometry = healpix_ds.dggs.cell_boundaries() geometry = healpix_ds.dggs.cell_boundaries()
# Create GeoDataFrame # Create GeoDataFrame
grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326") grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326") # ty:ignore[no-matching-overload]
grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2 grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2
@ -250,18 +250,18 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
""" """
fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()}) fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()})
ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree()) ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree()) # ty:ignore[unresolved-attribute]
# Add features # Add features
ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white") ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white") # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey") ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey") # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.COASTLINE) ax.add_feature(cfeature.COASTLINE) # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.BORDERS, linestyle=":") ax.add_feature(cfeature.BORDERS, linestyle=":") # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.LAKES, alpha=0.5) ax.add_feature(cfeature.LAKES, alpha=0.5) # ty:ignore[unresolved-attribute]
ax.add_feature(cfeature.RIVERS) ax.add_feature(cfeature.RIVERS) # ty:ignore[unresolved-attribute]
# Add gridlines # Add gridlines
gl = ax.gridlines(draw_labels=True) gl = ax.gridlines(draw_labels=True) # ty:ignore[unresolved-attribute]
gl.top_labels = False gl.top_labels = False
gl.right_labels = False gl.right_labels = False
@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
verts = np.vstack([np.sin(theta), np.cos(theta)]).T verts = np.vstack([np.sin(theta), np.cos(theta)]).T
circle = mpath.Path(verts * radius + center) circle = mpath.Path(verts * radius + center)
ax.set_boundary(circle, transform=ax.transAxes) ax.set_boundary(circle, transform=ax.transAxes) # ty:ignore[unresolved-attribute]
return fig return fig

View file

@ -15,8 +15,9 @@ DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcodin
GRIDS_DIR = DATA_DIR / "grids" GRIDS_DIR = DATA_DIR / "grids"
FIGURES_DIR = Path("figures") FIGURES_DIR = Path("figures")
RTS_DIR = DATA_DIR / "darts-rts" DARTS_V1_DIR = DATA_DIR / "darts-v1"
RTS_LABELS_DIR = DATA_DIR / "darts-rts-mllabels" DARTS_V2_DIR = DATA_DIR / "darts-v2"
DARTS_MLLABELS_DIR = DATA_DIR / "darts-mllabels"
ERA5_DIR = DATA_DIR / "era5" ERA5_DIR = DATA_DIR / "era5"
ARCTICDEM_DIR = DATA_DIR / "arcticdem" ARCTICDEM_DIR = DATA_DIR / "arcticdem"
EMBEDDINGS_DIR = DATA_DIR / "embeddings" EMBEDDINGS_DIR = DATA_DIR / "embeddings"
@ -27,7 +28,9 @@ RESULTS_DIR = DATA_DIR / "results"
GRIDS_DIR.mkdir(parents=True, exist_ok=True) GRIDS_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True) FIGURES_DIR.mkdir(parents=True, exist_ok=True)
RTS_DIR.mkdir(parents=True, exist_ok=True) DARTS_V1_DIR.mkdir(parents=True, exist_ok=True)
DARTS_V2_DIR.mkdir(parents=True, exist_ok=True)
DARTS_MLLABELS_DIR.mkdir(parents=True, exist_ok=True)
ERA5_DIR.mkdir(parents=True, exist_ok=True) ERA5_DIR.mkdir(parents=True, exist_ok=True)
ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True) ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True)
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
@ -39,10 +42,6 @@ DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
dartsl2_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet"
dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet"
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
def _get_gridname(grid: Grid, level: int) -> str: def _get_gridname(grid: Grid, level: int) -> str:
return f"permafrost_{grid}{level}" return f"permafrost_{grid}{level}"
@ -60,13 +59,18 @@ def get_grid_viz_file(grid: Grid, level: int) -> Path:
return vizfile return vizfile
def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path: def get_darts_file(grid: Grid, level: int, version: Literal["v1", "v1-l3", "v2", "v2-l3", "mllabels"]) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
if labels: match version:
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet" case "v1" | "v1-l3":
else: darts_file = DARTS_V1_DIR / f"{gridname}_darts-{version}.zarr"
rtsfile = RTS_DIR / f"{gridname}_darts.parquet" case "v2" | "v2-l3":
return rtsfile darts_file = DARTS_V2_DIR / f"{gridname}_darts-{version}.zarr"
case "mllabels":
darts_file = DARTS_MLLABELS_DIR / f"{gridname}_darts-{version}.zarr"
case _:
raise ValueError(f"Unknown DARTS version: {version}")
return darts_file
def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path: def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
@ -85,13 +89,20 @@ def get_era5_stores(
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily", agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
grid: Grid | None = None, grid: Grid | None = None,
level: int | None = None, level: int | None = None,
temporal: Literal["synopsis"] | None = None,
) -> Path: ) -> Path:
if grid is None or level is None: pdir = ERA5_DIR
(ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True) if temporal is not None:
return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr" agg += f"_{temporal}" # ty:ignore[invalid-assignment]
fname = f"{agg}_climate.zarr"
if grid is None or level is None:
pdir = pdir / "intermediate"
pdir.mkdir(parents=True, exist_ok=True)
else:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr" fname = f"{gridname}_{fname}"
aligned_path = pdir / fname
return aligned_path return aligned_path
@ -122,6 +133,12 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int
return cache_file return cache_file
def get_features_cache(ensemble_id: str, cells_hash: str) -> Path:
    """Return the parquet cache file for extracted features of one cell set.

    Creates the per-ensemble cache directory on first use so callers can
    write to the returned path directly.

    Args:
        ensemble_id: Identifier of the dataset ensemble.
        cells_hash: Hash identifying the concrete set of cells.

    Returns:
        Path to ``cells_{cells_hash}.parquet`` inside the ensemble's cache dir.
    """
    features_dir = DATASET_ENSEMBLES_DIR / "cache" / "features" / ensemble_id
    features_dir.mkdir(parents=True, exist_ok=True)
    return features_dir / f"cells_{cells_hash}.parquet"
def get_cv_results_dir( def get_cv_results_dir(
name: str, name: str,
grid: Grid, grid: Grid,
@ -133,3 +150,15 @@ def get_cv_results_dir(
results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}" results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}"
results_dir.mkdir(parents=True, exist_ok=True) results_dir.mkdir(parents=True, exist_ok=True)
return results_dir return results_dir
def get_autogluon_results_dir(
    grid: Grid,
    level: int,
    task: Task,
) -> Path:
    """Create and return a timestamped results directory for an AutoGluon run.

    The directory name encodes the grid name, a second-resolution timestamp,
    and the task, so repeated runs never collide.

    Args:
        grid: Grid family (e.g. hex / healpix).
        level: Grid resolution level.
        task: Prediction task the run targets.

    Returns:
        Path to the freshly created results directory.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    out_dir = RESULTS_DIR / f"{_get_gridname(grid, level)}_autogluon_{timestamp}_{task}"
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir

View file

@ -15,11 +15,14 @@ type GridLevel = Literal[
"healpix9", "healpix9",
"healpix10", "healpix10",
] ]
type TargetDataset = Literal["darts_rts", "darts_mllabels"] type TargetDataset = Literal["darts_v2", "darts_v1", "darts_mllabels", "darts_rts"]
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"] type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"] type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count", "density"] type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"]
# TODO: Consider implementing a "timeseries" temporal mode
type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
type Model = Literal["espa", "xgboost", "rf", "knn"] type Model = Literal["espa", "xgboost", "rf", "knn"]
type Stage = Literal["train", "inference", "visualization"]
@dataclass(frozen=True) @dataclass(frozen=True)
@ -70,8 +73,9 @@ class GridConfig:
# Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists # Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists
all_tasks: list[Task] = ["binary", "count", "density"] all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"]
all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"] all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024]
all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"]
all_l2_source_datasets: list[L2SourceDataset] = [ all_l2_source_datasets: list[L2SourceDataset] = [
"ArcticDEM", "ArcticDEM",
"ERA5-shoulder", "ERA5-shoulder",