From 231caa62e7f979f76553cc45ee2d72763ae8c7e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 11 Jan 2026 15:57:14 +0100 Subject: [PATCH] Refactor dataset ensemble to allow different temporal modes --- scripts/01darts.sh | 30 +- src/entropice/ingest/darts.py | 259 ++++++++--- src/entropice/ml/dataset.py | 823 +++++++++++++++++---------------- src/entropice/spatial/grids.py | 24 +- src/entropice/utils/paths.py | 65 ++- src/entropice/utils/types.py | 12 +- 6 files changed, 708 insertions(+), 505 deletions(-) diff --git a/scripts/01darts.sh b/scripts/01darts.sh index 720d860..88da9bb 100644 --- a/scripts/01darts.sh +++ b/scripts/01darts.sh @@ -1,15 +1,27 @@ #! /bin/bash # pixi shell -darts extract-darts-rts --grid hex --level 3 -darts extract-darts-rts --grid hex --level 4 -darts extract-darts-rts --grid hex --level 5 -darts extract-darts-rts --grid hex --level 6 -darts extract-darts-rts --grid healpix --level 6 -darts extract-darts-rts --grid healpix --level 7 -darts extract-darts-rts --grid healpix --level 8 -darts extract-darts-rts --grid healpix --level 9 -darts extract-darts-rts --grid healpix --level 10 +darts extract-darts-v1 --grid hex --level 3 +darts extract-darts-v1 --grid hex --level 4 +darts extract-darts-v1 --grid hex --level 5 +darts extract-darts-v1 --grid hex --level 6 +darts extract-darts-v1 --grid healpix --level 6 +darts extract-darts-v1 --grid healpix --level 7 +darts extract-darts-v1 --grid healpix --level 8 +darts extract-darts-v1 --grid healpix --level 9 +darts extract-darts-v1 --grid healpix --level 10 + + +darts extract-darts-v1-aggregated --grid hex --level 3 +darts extract-darts-v1-aggregated --grid hex --level 4 +darts extract-darts-v1-aggregated --grid hex --level 5 +darts extract-darts-v1-aggregated --grid hex --level 6 +darts extract-darts-v1-aggregated --grid healpix --level 6 +darts extract-darts-v1-aggregated --grid healpix --level 7 +darts extract-darts-v1-aggregated --grid healpix --level 8 +darts extract-darts-v1-aggregated --grid healpix --level 9 +darts extract-darts-v1-aggregated --grid healpix --level 10 + darts extract-darts-mllabels --grid hex --level 3 darts extract-darts-mllabels --grid hex --level 4 diff --git a/src/entropice/ingest/darts.py b/src/entropice/ingest/darts.py index 4e3dfd1..819bd2e 100644 --- a/src/entropice/ingest/darts.py +++ b/src/entropice/ingest/darts.py @@ -9,28 +9,140 @@ Date: October 2025 import cyclopts import geopandas as gpd import pandas as pd -from rich import pretty, print, traceback -from rich.progress import track +import xarray as xr +import xdggs +from rich import pretty, traceback from stopuhr import stopwatch from entropice.spatial import grids from entropice.utils.paths import ( - darts_ml_training_labels_repo, - dartsl2_cov_file, - dartsl2_file, - get_darts_rts_file, + DARTS_MLLABELS_DIR, + DARTS_V1_DIR, + get_darts_file, ) from entropice.utils.types import Grid traceback.install() pretty.install() +darts_v1_l2_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet" +darts_v1_l2_cov_file = DARTS_V1_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet" +darts_ml_training_labels_repo = DARTS_MLLABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps" + cli = cyclopts.App(name="darts-rts") +def _load_grid(grid: Grid, level: int) -> tuple[gpd.GeoDataFrame, xr.DataArray]: + grid_gdf = grids.open(grid, level) + if grid == "hex": + grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)) + cell_areas = xr.DataArray( + 
grid_gdf["cell_area"].to_numpy(), dims=["cell_id"], coords={"cell_id": grid_gdf["cell_id"].to_numpy()} + ) + # We only want to use the geometry and the cell_id + grid_gdf = grid_gdf[["cell_id", "geometry"]] + return grid_gdf, cell_areas + + +@stopwatch("Convert xdgss") +def _convert_xdggs(darts: xr.Dataset, grid: Grid, level: int) -> xr.Dataset: + darts = darts.rename({"cell_id": "cell_ids"}) + darts["cell_ids"] = darts["cell_ids"].astype("uint64") + gridinfo = { + "grid_name": "h3" if grid == "hex" else grid, + "level": level, + } + if grid == "healpix": + gridinfo["indexing_scheme"] = "nested" + darts.cell_ids.attrs = gridinfo + darts = xdggs.decode(darts) + return darts + + +@stopwatch("Process RTS grid") +def _process_rts_grid( + rts: gpd.GeoDataFrame, + cov: gpd.GeoDataFrame, + cell_areas: xr.DataArray, +) -> xr.Dataset: + assert "cell_id" in rts.columns, "RTS data must contain cell_id column." + assert "cell_id" in cov.columns, "Coverage data must contain cell_id column." + + cov["area_km2"] = cov.geometry.area / 1e6 + covered_area: pd.Series = cov.pivot_table( + index="cell_id", + values="area_km2", + aggfunc="sum", + fill_value=0, + )["area_km2"] + covered_area: xr.DataArray = xr.DataArray.from_series(covered_area) + coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id) + rts["area_km2"] = rts.geometry.area / 1e6 + rts = rts.pivot_table( + index="cell_id", + values=["geometry", "area_km2"], + fill_value=0, + aggfunc={"geometry": "count", "area_km2": "sum"}, # type: ignore[arg-type] + ) + rts: xr.Dataset = xr.Dataset.from_dataframe(rts) + darts = xr.merge( + [rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0 + ) + darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0) + return darts + + +@stopwatch("Process RTS yearly grid") +def _process_rts_yearly_grid( + rts: gpd.GeoDataFrame, + cov: gpd.GeoDataFrame, + cell_areas: xr.DataArray, +) -> xr.Dataset: + assert "cell_id" in rts.columns, "RTS data must contain cell_id column." + assert "cell_id" in cov.columns, "Coverage data must contain cell_id column." + + cov["area_km2"] = cov.geometry.area / 1e6 + covered_area: pd.Series = cov.pivot_table( # type: ignore[var-annotated] # noqa: PD010 + index="cell_id", + columns="year", + values="area_km2", + aggfunc="sum", + fill_value=0, + ).unstack() + covered_area: xr.DataArray = xr.DataArray.from_series(covered_area) + coverage = covered_area / cell_areas.sel(cell_id=covered_area.cell_id) + rts: pd.DataFrame = ( + rts.pivot_table( # type: ignore[var-annotated] # noqa: PD013 + index="cell_id", + columns="year", + values=["geometry", "area_km2"], + fill_value=0, + aggfunc={"geometry": "count", "area_km2": "sum"}, # type: ignore[arg-type] + ) + .stack(future_stack=True) + .rename(columns={"geometry": "count"}) + ) + rts: xr.Dataset = xr.Dataset.from_dataframe(rts) + darts = xr.merge( + [rts, covered_area.rename("covered_area_km2"), coverage.rename("coverage")], join="outer", fill_value=0 + ) + darts["density"] = (darts["area_km2"] / darts["covered_area_km2"]).fillna(0) + return darts + + @cli.command() -def extract_darts_rts(grid: Grid, level: int): - """Extract RTS labels from DARTS dataset. +def extract_darts_v1(grid: Grid, level: int): + """Extract RTS labels from DARTS-v1 Level-2 dataset. + + Creates a Darts-v1 xarray Dataset on the specified grid and level. 
+ The Dataset contains the following variables: + - count: Number of RTS in the cell + - area_km2: Total area of RTS in the cell (in km^2) + - covered_area_km2: Area of the cell covered by DARTS (in km^2) + - coverage: Fraction of the cell covered by DARTS + - density: Density of RTS area per covered area (area_km2 / covered_area_km2) + Since the DARTS-v1 Level-2 dataset contains yearly data, all variables are indexed by year as well. + Thus each variable has dimensions (cell_ids, year). Args: grid (Grid): The grid type to use. @@ -38,66 +150,81 @@ def extract_darts_rts(grid: Grid, level: int): """ with stopwatch("Load data"): - darts_l2 = gpd.read_parquet(dartsl2_file) - darts_cov_l2 = gpd.read_parquet(dartsl2_cov_file) - grid_gdf = grids.open(grid, level) + darts_l2 = gpd.read_parquet(darts_v1_l2_file) + darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file) + grid_gdf, cell_areas = _load_grid(grid, level) - with stopwatch("Extract RTS labels"): + with stopwatch("Assign RTS to grid"): grid_l2 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection") grid_cov_l2 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection") - years = list(grid_cov_l2["year"].unique()) - for year in track(years, total=len(years), description="Processing years..."): - with stopwatch("Processing RTS", log=False): - subset = grid_l2[grid_l2["year"] == year] - subset_cov = grid_cov_l2[grid_cov_l2["year"] == year] + darts = _process_rts_yearly_grid(grid_l2, grid_cov_l2, cell_areas) - counts = subset.groupby("cell_id").size() - grid_gdf[f"darts_{year}_rts_count"] = grid_gdf.cell_id.map(counts) + darts = _convert_xdggs(darts, grid, level) + output_path = get_darts_file(grid, level, version="v1") + with stopwatch(f"Writing Darts v1 to {output_path}"): + darts.to_zarr(output_path, consolidated=False, mode="w") - areas = subset.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False) - grid_gdf[f"darts_{year}_rts_area"] = grid_gdf.cell_id.map(areas) - areas_cov = subset_cov.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False) - grid_gdf[f"darts_{year}_covered_area"] = grid_gdf.cell_id.map(areas_cov) - grid_gdf[f"darts_{year}_coverage"] = grid_gdf[f"darts_{year}_covered_area"] / grid_gdf.geometry.area +@cli.command() +def extract_darts_v1_aggregated(grid: Grid, level: int): + """Extract RTS labels from DARTS-v1 Level-3 dataset. - grid_gdf[f"darts_{year}_rts_density"] = ( - grid_gdf[f"darts_{year}_rts_area"] / grid_gdf[f"darts_{year}_covered_area"] - ) + Creates a Darts-v1 xarray Dataset on the specified grid and level. + The Dataset contains the following variables: + - count: Number of RTS in the cell + - area_km2: Total area of RTS in the cell (in km^2) + - covered_area_km2: Area of the cell covered by DARTS (in km^2) + - coverage: Fraction of the cell covered by DARTS + - density: Density of RTS area per covered area (area_km2 / covered_area_km2) + Since the DARTS-v1 Level-2 dataset contains yearly data, the data is dissolved then exploded to obtain Level-3 data. + Thus each variable has only the dimension (cell_ids). 
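# Illustrative sketch (not part of this patch): the per-cell aggregation performed by
# _process_rts_grid above, shown with toy data. Column names ("cell_id", "area_km2",
# "geometry") follow the code in this patch; the values are invented for the example.
import pandas as pd
import xarray as xr

toy_rts = pd.DataFrame(
    {
        "cell_id": [1, 1, 2],
        "area_km2": [0.50, 0.25, 1.00],
        "geometry": ["poly_a", "poly_b", "poly_c"],  # stand-ins for shapely polygons
    }
)
per_cell = toy_rts.pivot_table(
    index="cell_id",
    values=["geometry", "area_km2"],
    fill_value=0,
    aggfunc={"geometry": "count", "area_km2": "sum"},
)
# One row per cell: "geometry" holds the RTS count, "area_km2" the summed RTS area.
per_cell_ds = xr.Dataset.from_dataframe(per_cell)
# Merging with the per-cell covered area and dividing yields the "coverage" and "density"
# variables described in the docstring above.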
- # Apply corrections to NaNs - covered = ~grid_gdf[f"darts_{year}_coverage"].isna() - grid_gdf.loc[covered, f"darts_{year}_rts_count"] = grid_gdf.loc[covered, f"darts_{year}_rts_count"].fillna( - 0.0 - ) - grid_gdf.loc[covered, f"darts_{year}_rts_density"] = grid_gdf.loc[ - covered, f"darts_{year}_rts_density" - ].fillna(0.0) - grid_gdf[f"darts_{year}_has_coverage"] = covered - grid_gdf[f"darts_{year}_has_rts"] = grid_gdf[f"darts_{year}_rts_count"] > 0 + Args: + grid (Grid): The grid type to use. + level (int): The grid level to use. - grid_gdf["darts_has_coverage"] = grid_gdf[[f"darts_{year}_coverage" for year in years]].any(axis=1) - grid_gdf["darts_has_rts"] = grid_gdf[[f"darts_{year}_rts_count" for year in years]].any(axis=1) + """ + with stopwatch("Load data"): + darts_l2 = gpd.read_parquet(darts_v1_l2_file) + darts_cov_l2 = gpd.read_parquet(darts_v1_l2_cov_file) + grid_gdf, cell_areas = _load_grid(grid, level) + # Remove overlapping labels by dissolving + darts_l2 = darts_l2[["geometry"]].dissolve().explode() + darts_cov_l2 = darts_cov_l2[["geometry"]].dissolve().explode() - darts_counts_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_count")] - darts_counts = grid_gdf[darts_counts_columns] - grid_gdf["darts_rts_count"] = darts_counts.dropna(axis=0, how="all").sum(axis=1) + with stopwatch("Extract RTS labels"): + grid_l3 = grid_gdf.overlay(darts_l2.to_crs(grid_gdf.crs), how="intersection") + grid_cov_l3 = grid_gdf.overlay(darts_cov_l2.to_crs(grid_gdf.crs), how="intersection") - darts_density_columns = [c for c in grid_gdf.columns if c.startswith("darts_") and c.endswith("_rts_density")] - darts_density = grid_gdf[darts_density_columns] - grid_gdf["darts_rts_density"] = darts_density.dropna(axis=0, how="all").max(axis=1) - - output_path = get_darts_rts_file(grid, level) - grid_gdf.to_parquet(output_path) - print(f"Saved RTS labels to {output_path}") - stopwatch.summary() + darts = _process_rts_grid(grid_l3, grid_cov_l3, cell_areas) + darts = _convert_xdggs(darts, grid, level) + output_path = get_darts_file(grid, level, version="v1-l3") + with stopwatch(f"Writing Darts v1 l3 to {output_path}"): + darts.to_zarr(output_path, consolidated=False, mode="w") @cli.command() def extract_darts_mllabels(grid: Grid, level: int): + """Extract RTS labels from the DARTS-mllabels dataset. + + Creates a Darts-mllabels xarray Dataset on the specified grid and level. + The Dataset contains the following variables: + - count: Number of RTS in the cell + - area_km2: Total area of RTS in the cell (in km^2) + - covered_area_km2: Area of the cell covered by DARTS (in km^2) + - coverage: Fraction of the cell covered by DARTS + - density: Density of RTS area per covered area (area_km2 / covered_area_km2) + Since the DARTS-mllabels dataset contains not much data, all variables are aggregated over the entire time period. + Thus each variable has only the dimension (cell_ids). + + Args: + grid (Grid): The grid type to use. + level (int): The grid level to use. 
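# Illustrative sketch (not part of this patch): opening one of the extracted stores again.
# get_darts_file and the "mllabels" version string come from this patch; the grid/level
# values are arbitrary examples and assume the matching extract command has already been run.
import xarray as xr
from entropice.utils.paths import get_darts_file

store = get_darts_file("healpix", 8, version="mllabels")
darts = xr.open_zarr(store, consolidated=False)
# For this target the variables listed in the docstring above are indexed by cell_ids only.
print(darts)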
+ + """ with stopwatch("Load data"): - grid_gdf = grids.open(grid, level) + grid_gdf, cell_areas = _load_grid(grid, level) darts_mllabels = ( gpd.GeoDataFrame( pd.concat([gpd.read_file(f) for f in darts_ml_training_labels_repo.glob("**/TrainingLabel*.gpkg")]) @@ -118,34 +245,16 @@ def extract_darts_mllabels(grid: Grid, level: int): .to_crs(grid_gdf.crs) ) darts_cov_mllabels = darts_cov_mllabels[["geometry"]].dissolve().explode() + with stopwatch("Extract RTS labels"): grid_mllabels = grid_gdf.overlay(darts_mllabels.to_crs(grid_gdf.crs), how="intersection") grid_cov_mllabels = grid_gdf.overlay(darts_cov_mllabels.to_crs(grid_gdf.crs), how="intersection") - with stopwatch("Processing RTS"): - counts = grid_mllabels.groupby("cell_id").size() - grid_gdf["dartsml_rts_count"] = grid_gdf.cell_id.map(counts) - - areas = grid_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False) - grid_gdf["dartsml_rts_area"] = grid_gdf.cell_id.map(areas) - - areas_cov = grid_cov_mllabels.groupby("cell_id").apply(lambda x: x.geometry.area.sum(), include_groups=False) - grid_gdf["dartsml_covered_area"] = grid_gdf.cell_id.map(areas_cov) - grid_gdf["dartsml_coverage"] = grid_gdf["dartsml_covered_area"] / grid_gdf.geometry.area - grid_gdf["dartsml_rts_density"] = grid_gdf["dartsml_rts_area"] / grid_gdf["dartsml_covered_area"] - - # Apply corrections to NaNs - covered = ~grid_gdf["dartsml_coverage"].isna() - grid_gdf.loc[covered, "dartsml_rts_count"] = grid_gdf.loc[covered, "dartsml_rts_count"].fillna(0.0) - grid_gdf.loc[covered, "dartsml_rts_density"] = grid_gdf.loc[covered, "dartsml_rts_density"].fillna(0.0) - - grid_gdf["dartsml_has_coverage"] = covered - grid_gdf["dartsml_has_rts"] = grid_gdf["dartsml_rts_count"] > 0 - - output_path = get_darts_rts_file(grid, level, labels=True) - grid_gdf.to_parquet(output_path) - print(f"Saved RTS labels to {output_path}") - stopwatch.summary() + darts = _process_rts_grid(grid_mllabels, grid_cov_mllabels, cell_areas) + darts = _convert_xdggs(darts, grid, level) + output_path = get_darts_file(grid, level, version="mllabels") + with stopwatch(f"Writing Darts v1 to {output_path}"): + darts.to_zarr(output_path, consolidated=False, mode="w") def main(): # noqa: D103 diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py index 28ce0ba..1fa3608 100644 --- a/src/entropice/ml/dataset.py +++ b/src/entropice/ml/dataset.py @@ -1,4 +1,4 @@ -# ruff: noqa: N806 +# ruff: noqa: N806, D105 """Training dataset preparation and model training. 
Naming conventions: @@ -16,9 +16,9 @@ import hashlib import json from collections.abc import Generator from dataclasses import asdict, dataclass, field -from functools import cached_property +from functools import cache, cached_property from itertools import product -from typing import Literal, TypedDict +from typing import Literal, cast import cupy as cp import cyclopts @@ -33,8 +33,9 @@ from sklearn import set_config from sklearn.model_selection import train_test_split from stopuhr import stopwatch +import entropice.spatial.grids import entropice.utils.paths -from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task +from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task, TemporalMode traceback.install() pretty.install() @@ -44,42 +45,53 @@ set_config(array_api_dispatch=True) sns.set_theme("talk", "whitegrid") -covcol: dict[TargetDataset, str] = { - "darts_rts": "darts_has_coverage", - "darts_mllabels": "dartsml_has_coverage", -} +def _unstack_era5_time(era5: xr.Dataset, aggregation: Literal["yearly", "seasonal", "shoulder"]) -> xr.Dataset: + # In the yearly case, no unstacking is necessary, we can just rename the time dimension to year and change the coord + if aggregation == "yearly": + era5 = era5.rename({"time": "year"}) + era5.coords["year"] = era5["year"].dt.year + return era5 -taskcol: dict[Task, dict[TargetDataset, str]] = { - "binary": { - "darts_rts": "darts_has_rts", - "darts_mllabels": "dartsml_has_rts", - }, - "count": { - "darts_rts": "darts_rts_count", - "darts_mllabels": "dartsml_rts_count", - }, - "density": { - "darts_rts": "darts_rts_density", - "darts_mllabels": "dartsml_rts_density", - }, -} + # Make the time index a MultiIndex of year and month + era5.coords["year"] = era5.time.dt.year + era5.coords["month"] = era5.time.dt.month + era5["time"] = pd.MultiIndex.from_arrays( + [ + era5.time.dt.year.values, # noqa: PD011 + era5.time.dt.month.values, # noqa: PD011 + ], + names=("year", "month"), + ) + era5 = era5.unstack("time") # noqa: PD010 + seasons = {10: "winter", 4: "summer"} + shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"} + month_map = seasons if aggregation == "seasonal" else shoulder_seasons + era5.coords["month"] = era5["month"].to_series().map(month_map) + return era5 -def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): - time_index = pd.DatetimeIndex(df.index.get_level_values("time")) - if temporal == "yearly": - return time_index.year - elif temporal == "seasonal": - seasons = {10: "winter", 4: "summer"} - return time_index.month.map(lambda x: seasons.get(x)).str.cat(time_index.year.astype(str), sep="_") - elif temporal == "shoulder": - shoulder_seasons = {10: "OND", 1: "JFM", 4: "AMJ", 7: "JAS"} - return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_") +def _collapse_to_dataframe(ds: xr.Dataset | xr.DataArray) -> pd.DataFrame: + collapsed = ds.to_dataframe() + # Make a dummy row to avoid empty dataframe issues + use_dummy = collapsed.shape[0] == 0 + if use_dummy: + collapsed.loc[tuple(range(len(collapsed.index.names)))] = np.nan + pivcols = set(collapsed.index.names) - {"cell_ids"} + collapsed = collapsed.pivot_table(index="cell_ids", columns=pivcols) + collapsed.columns = ["_".join(v) for v in collapsed.columns] + if use_dummy: + collapsed = collapsed.dropna(how="all") + return collapsed + + +def _cell_ids_hash(cell_ids: pd.Series) -> str: + sorted_ids = np.sort(cell_ids.to_numpy()) + return 
hashlib.blake2b(sorted_ids.tobytes(), digest_size=8).hexdigest() def bin_values( values: pd.Series, - task: Literal["count", "density"], + task: Literal["count_regimes", "density_regimes"], none_val: float = 0, ) -> pd.Series: """Bin values into predefined intervals for different tasks. @@ -89,7 +101,7 @@ def bin_values( Args: values (pd.Series): Pandas Series of numerical values to bin. - task (Literal["count", "density"]): Task type - 'count' or 'density'. + task (Literal["count_regimes", "density_regimes"]): Task type - 'count_regimes' or 'density_regimes'. none_val (float, optional): Value representing 'none' or 'empty' (e.g., 0 for count). Defaults to 0. Returns: @@ -100,15 +112,8 @@ def bin_values( """ labels_dict = { - "count": ["None", "Very Few", "Few", "Several", "Many", "Very Many"], - "density": [ - "Empty", - "Very Sparse", - "Sparse", - "Moderate", - "Dense", - "Very Dense", - ], + "count_regimes": ["None", "Very Few", "Few", "Several", "Many", "Very Many"], + "density_regimes": ["Empty", "Very Sparse", "Sparse", "Moderate", "Dense", "Very Dense"], } labels = labels_dict[task] @@ -138,62 +143,67 @@ def bin_values( @dataclass(frozen=True, eq=False) -class DatasetLabels: - binned: pd.Series +class SplittedArrays: + """Small wrapper for train and test arrays.""" + train: torch.Tensor | np.ndarray | cp.ndarray test: torch.Tensor | np.ndarray | cp.ndarray - raw_values: pd.Series + + +@dataclass(frozen=True, eq=False) +class TrainingSet: + """Container for the training dataset.""" + + targets: gpd.GeoDataFrame + features: pd.DataFrame + X: SplittedArrays + y: SplittedArrays + z: pd.Series + split: pd.Series @cached_property - def intervals(self) -> list[tuple[float, float] | tuple[int, int]]: + def target_intervals(self) -> list[tuple[float, float] | tuple[int, int] | tuple[None, None]]: + """Calculate the intervals for each target category. + + Returns: + list[tuple[float, float] | tuple[int, int] | tuple[None, None]]: + List of (min, max) tuples for each category. + + """ + binned = self.targets["y"] + raw = self.targets["z"] + assert binned.dtype.name == "category", "Target labels are not categorical." + # For each category get the min and max values from raw_values - intervals = [] - for category in self.binned.cat.categories: - category_mask = self.binned == category + intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]] = [] + for category in binned.cat.categories: + category_mask = binned == category if category_mask.sum() == 0: intervals.append((None, None)) else: - category_raw_values = self.raw_values[category_mask] + category_raw_values = raw[category_mask] intervals.append((category_raw_values.min(), category_raw_values.max())) return intervals @cached_property - def labels(self) -> list[str]: - return list(self.binned.cat.categories) - - -@dataclass(frozen=True, eq=False) -class DatasetInputs: - data: pd.DataFrame - train: torch.Tensor | np.ndarray | cp.ndarray - test: torch.Tensor | np.ndarray | cp.ndarray - - -@dataclass(frozen=True) -class CategoricalTrainingDataset: - dataset: gpd.GeoDataFrame - X: DatasetInputs - y: DatasetLabels - z: pd.Series - split: pd.Series + def target_labels(self) -> list[str]: + """Labels of the target categories.""" + binned = self.targets["y"] + assert binned.dtype.name == "category", "Target labels are not categorical." 
+ return list(binned.cat.categories) def __len__(self): return len(self.z) -class DatasetStats(TypedDict): - target: str - num_target_samples: int - members: dict[str, dict[str, object]] - total_features: int - - @cyclopts.Parameter("*") @dataclass(frozen=True, kw_only=True) class DatasetEnsemble: + """An ensemble of datasets for training and inference.""" + grid: Grid level: int - target: Literal["darts_rts", "darts_mllabels"] + target: TargetDataset members: list[L2SourceDataset] = field( default_factory=lambda: [ "AlphaEarth", @@ -203,366 +213,420 @@ class DatasetEnsemble: "ERA5-shoulder", ] ) - dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict) - variable_filters: dict[str, list[str]] = field(default_factory=dict) - filter_target: str | Literal[False] = False + temporal_mode: TemporalMode = "synopsis" + dimension_filters: dict[L2SourceDataset, dict[str, list]] = field(default_factory=dict) + variable_filters: dict[L2SourceDataset, list[str]] = field(default_factory=dict) add_lonlat: bool = True + def __post_init__(self): + # Validate filters + for member in self.members: + if member in self.variable_filters: + # This enforces the assumption that the read members always are xarray Datasets and not DataArrays + assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, ( + f"Invalid variable filter for {member=}: {self.variable_filters[member]}" + " Variable filter values must be a list with one or more entries." + ) + if member in self.dimension_filters: + for dim, values in self.dimension_filters[member].items(): + # This enforces the assumption that we do the temporal filtering via the temporal_mode + assert dim not in ["year", "month", "time"], ( + f"Invalid dimension filter for {member=}: {dim}" + " Filtering on 'year', 'month' or 'time' is not supported." + ) + # This enforces the assumption that there are no empty dimensions in the Dataset + assert isinstance(values, list) and len(values) >= 1, ( + f"Invalid dimension filter for {dim=}: {values}" + " Dimension filter values must be a list with one or more entries." 
+ )
+
 def __hash__(self):
 return int(self.id(), 16)
+ @cache
 def id(self):
+ """Return a unique, stable identifier based on the settings of this class."""
 return hashlib.blake2b(
 json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
 digest_size=16,
 ).hexdigest()
- @property
- def covcol(self) -> str:
- return covcol[self.target]
+ @cached_property
+ def cell_ids(self) -> pd.Series:
+ return self.read_grid()["cell_id"]
- def taskcol(self, task: Task) -> str:
- return taskcol[task][self.target]
+ @cached_property
+ def geometries(self) -> pd.Series:
+ return self.read_grid()["geometry"]
- def _read_member(self, member: L2SourceDataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
- if member == "AlphaEarth":
- store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
- elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
- era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment]
- store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level)
- elif member == "ArcticDEM":
- store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level)
- else:
- raise NotImplementedError(f"Member {member} not implemented.")
+ # @stopwatch("Reading grid")
+ def read_grid(self) -> gpd.GeoDataFrame:
+ grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
+ # Add the lat / lon of the cell centers
+ if self.add_lonlat:
+ grid_gdf["lon"] = grid_gdf.geometry.centroid.x
+ grid_gdf["lat"] = grid_gdf.geometry.centroid.y
+
+ # Convert hex cell_id to int
+ if self.grid == "hex":
+ grid_gdf["cell_id"] = grid_gdf["cell_id"].apply(lambda x: int(x, 16)).astype(np.uint64)
+
+ grid_gdf = grid_gdf.set_index("cell_id")
+ return grid_gdf
+
+ # @stopwatch.f("Getting target labels", print_kwargs=["stage"])
+ def get_targets(self, task: Task) -> gpd.GeoDataFrame:
+ """Create training target labels for a specific task.
+
+ The function reads the target dataset, filters it based on coverage,
+ and prepares the target labels according to the specified task.
+
+ This is quite a fast function, with the bottleneck usually being the reading of the zarr store
+ (milliseconds to a few seconds).
+
+ This is intended to be used by the training dataset creation functions
+ and for easy visualization of the targets.
+
+ Args:
+ task (Task): The task.
+
+ Returns:
+ gpd.GeoDataFrame: GeoDataFrame with target labels, geometry and raw target values,
+ all indexed by the cell_ids.
+ Columns: ["geometry", "y", "z"], where "y" is the target label and "z" the raw value.
+ For categorical tasks, "y" is categorical and "z" is numerical.
+ For regression tasks, both "y" and "z" are numerical and identical.
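# Illustrative usage sketch (not part of this patch). The DatasetEnsemble fields and the
# get_targets signature are taken from the code above; the concrete grid/level/target values
# are arbitrary examples and assume the corresponding target store already exists.
ensemble = DatasetEnsemble(grid="healpix", level=8, target="darts_v1", temporal_mode="synopsis")
targets = ensemble.get_targets("binary")
# targets is indexed by cell_id with columns ["geometry", "y", "z"]; for the "binary" task,
# "y" is the categorical "No RTS"/"RTS" label and "z" the raw RTS count.
print(targets["y"].value_counts())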
+ + """ + match (self.target, self.temporal_mode): + case ("darts_v1" | "darts_v2", "feature" | "synopsis"): + version: Literal["v1-l3", "v2-l3"] = self.target.split("_")[1] + "-l3" # ty:ignore[invalid-assignment] + target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version) + targets = xr.open_zarr(target_store, consolidated=False) + case ("darts_v1" | "darts_v2", int()): + version: Literal["v1", "v2"] = self.target.split("_")[1] # ty:ignore[invalid-assignment] + target_store = entropice.utils.paths.get_darts_file(grid=self.grid, level=self.level, version=version) + targets = xr.open_zarr(target_store, consolidated=False).sel(year=self.temporal_mode) + case ("darts_mllabels", str()): # Years are not supported + target_store = entropice.utils.paths.get_darts_file( + grid=self.grid, level=self.level, version="mllabels" + ) + targets = xr.open_zarr(target_store, consolidated=False) + case _: + raise NotImplementedError(f"Target {self.target} on {self.temporal_mode} mode not supported.") + targets = cast(xr.Dataset, targets) + covered_cell_ids = targets["coverage"].where(targets["coverage"] > 0).dropna("cell_ids")["cell_ids"].to_series() + targets = targets.sel(cell_ids=covered_cell_ids.to_numpy()) + match task: + case "binary": + z = targets["count"].to_series() + y = (z > 0).astype("category").map({False: "No RTS", True: "RTS"}) + case "count_regimes": + z = targets["count"].to_series() + y = bin_values(z, task="count_regimes", none_val=0) + case "density_regimes": + z = targets["density"].to_series() + y = bin_values(z, task="density_regimes", none_val=0) + case "count": + z = targets["count"].to_series() + y = z + case "density": + z = targets["density"].to_series() + y = z + case _: + raise NotImplementedError(f"Task {task} not supported.") + cell_ids = targets["cell_ids"].to_series() + geometries = self.geometries.loc[cell_ids] + return gpd.GeoDataFrame( + { + "cell_id": cell_ids, + "geometry": geometries, + "y": y, + "z": z, + } + ).set_index("cell_id") + + # @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"]) + def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901 + """Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration. + + When `lazy` is False, the data is actually read into memory, which can take some time (seconds). + + Args: + member (L2SourceDataset): The member + cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read. + Defaults to None. + lazy (bool, optional): Whether to not load the data and instead return a lazy dataset. Defaults to False. + + Returns: + xr.Dataset: Xarray Dataset with the member data for the given cell IDs. 
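# Illustrative usage sketch (not part of this patch): lazily inspecting a single member with
# the read_member method documented above. The member name is one of the default members;
# the ensemble settings are arbitrary examples.
ens = DatasetEnsemble(grid="hex", level=5, target="darts_v1", temporal_mode="synopsis")
era5 = ens.read_member("ERA5-yearly", lazy=True)
# With temporal_mode="synopsis", ERA5/AlphaEarth variables are reduced to a per-cell mean plus
# a linear trend over the years (variables suffixed "_trend"); an integer temporal_mode selects
# that single year instead, and "feature" keeps the year dimension.
print(list(era5.data_vars))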
+ + """ + match member: + case "AlphaEarth": + store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level) + case "ERA5-yearly" | "ERA5-seasonal" | "ERA5-shoulder": + era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] + store = entropice.utils.paths.get_era5_stores(era5_agg, grid=self.grid, level=self.level) + case "ArcticDEM": + store = entropice.utils.paths.get_arcticdem_stores(grid=self.grid, level=self.level) + case _: + raise NotImplementedError(f"Member {member} not implemented.") ds = xr.open_zarr(store, consolidated=False) - # Apply variable filters + # Apply variable and dimension filters if member in self.variable_filters: - assert isinstance(self.variable_filters[member], list) and len(self.variable_filters[member]) >= 1, ( - f"Invalid variable filter for {member=}: {self.variable_filters[member]}" - " Variable filter values must be a list with one or more entries." - ) ds = ds[self.variable_filters[member]] - - # Apply dimension filters if member in self.dimension_filters: - for dim, values in self.dimension_filters[member].items(): - assert isinstance(values, list) and len(values) >= 1, ( - f"Invalid dimension filter for {dim=}: {values}" - " Dimension filter values must be a list with one or more entries." - ) - ds = ds.sel({dim: values}) - + ds = ds.sel(self.dimension_filters[member]) # Delete all coordinates which are not in the dimension - for coord in ds.coords: - if coord not in ds.dims: - ds = ds.drop_vars(coord) + ds = ds.drop_vars([coord for coord in ds.coords if coord not in ds.dims]) # Only load target cell ids - intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(targets["cell_id"].values)) + if cell_ids is None: + cell_ids = self.cell_ids + intersecting_cell_ids = set(ds["cell_ids"].values).intersection(set(cell_ids.to_numpy())) ds = ds.sel(cell_ids=list(intersecting_cell_ids)) + # Unstack era5 data if needed + if member.startswith("ERA5"): + era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] + ds = _unstack_era5_time(ds, era5_agg) + + # Apply the temporal mode + match (member.split("-"), self.temporal_mode): + case (["ERA5", _] | ["AlphaEarth"], "synopsis"): + ds_mean = ds.mean(dim="year") + ds_trend = ds.polyfit(dim="year", deg=1).sel(degree=1, drop=True) + # Rename all cols from "{var}_polyfit_coefficients" to "{var}_trend" + ds_trend = ds_trend.rename( + {var: str(var).replace("_polyfit_coefficients", "_trend") for var in ds_trend.data_vars} + ) + ds = xr.merge([ds_mean, ds_trend]) + case (["ArcticDEM"], "synopsis"): + pass # No temporal dimension + case (_, int() as year): + ds = ds.sel(year=year, drop=True) + case (_, "feature"): + pass + case _: + raise NotImplementedError(f"Temporal mode {self.temporal_mode} not implemented for member {member}.") + # Actually read data into memory if not lazy: ds.load() return ds - def _read_target(self) -> gpd.GeoDataFrame: - if self.target == "darts_rts": - target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level) - elif self.target == "darts_mllabels": - target_store = entropice.utils.paths.get_darts_rts_file(grid=self.grid, level=self.level, labels=True) - else: - raise NotImplementedError(f"Target {self.target} not implemented.") - targets = gpd.read_parquet(target_store) + def make_features( + self, + cell_ids: pd.Series | None = None, + cache_mode: Literal["none", "read", "overwrite"] = "none", + ) -> pd.DataFrame: + """Create a 
feature DataFrame for the given temporal task and cell IDs.
- # Filter to coverage
- if self.filter_target:
- targets = targets[targets[self.filter_target]]
- # Convert hex cell_id to int
- if self.grid == "hex":
- targets["cell_id"] = targets["cell_id"].apply(lambda x: int(x, 16))
+ This function reads all members of the ensemble, prepares their features based on the temporal task,
+ and combines them into a single DataFrame (no geometry).
+ It is quite computationally intensive (seconds to minutes), depending on the configuration and number of cell IDs.
+ To speed up repeated calls, a caching mechanism is implemented.
- # Add the lat / lon of the cell centers
- if self.add_lonlat:
- targets["lon"] = targets.geometry.centroid.x
- targets["lat"] = targets.geometry.centroid.y
+ The intended use of this function is solely the creation of training and inference datasets.
+ For visualization purposes it is recommended to use `read_member` with `lazy=True` directly.
- return targets
+ Args:
+ cell_ids (pd.Series | None, optional): The cell IDs to read. If None, all cell IDs are read.
+ Defaults to None.
+ cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode.
+ "none": No caching.
+ "read": Read from cache if exists, otherwise create and save to cache.
+ "overwrite": Always create and save to cache, overwriting existing cache.
+ Defaults to "none".
+ Returns:
+ pd.DataFrame: DataFrame with features for the given cell IDs.
+
+ """
+ if cell_ids is None:
+ cell_ids = self.cell_ids
+
+ # Caching mechanism
+ if cache_mode != "none":
+ cells_hash = _cell_ids_hash(cell_ids)
+ cache_file = entropice.utils.paths.get_features_cache(ensemble_id=self.id(), cells_hash=cells_hash)
+ if cache_mode == "read" and cache_file.exists():
+ with stopwatch("Loading features from cache"):
+ return pd.read_parquet(cache_file)
+
+ # Create features
+ dataset = self.read_grid().loc[cell_ids].drop(columns=["geometry"])
+ member_dfs = []
+ for member in self.members:
+ match member.split("-"):
+ case ["ERA5", era5_agg]:
+ member_dfs.append(self._prep_era5(cell_ids, era5_agg))
+ case ["AlphaEarth"]:
+ member_dfs.append(self._prep_embeddings(cell_ids))
+ case ["ArcticDEM"]:
+ member_dfs.append(self._prep_arcticdem(cell_ids))
+ case _:
+ raise NotImplementedError(f"Member {member} not implemented.")
+ with stopwatch("Joining datasets"):
+ dataset = dataset.join(member_dfs)
+ print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
+
+ if cache_mode in ["read", "overwrite"]:
+ with stopwatch("Saving features to cache"):
+ dataset.to_parquet(cache_file)
+ return dataset
+
+ # @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"])
 def _prep_era5(
 self,
- targets: gpd.GeoDataFrame,
- temporal: Literal["yearly", "seasonal", "shoulder"],
+ cell_ids: pd.Series,
+ era5_agg: Literal["yearly", "seasonal", "shoulder"],
 ) -> pd.DataFrame:
- era5 = self._read_member("ERA5-" + temporal, targets)
-
- if len(era5["cell_ids"]) == 0:
- # No data for these cells - create empty DataFrame with expected columns
- # Use the Dataset metadata to determine column structure
- variables = list(era5.data_vars)
- times = era5.coords["time"].to_numpy()
- time_df = pd.DataFrame({"time": times})
- time_df.index = pd.DatetimeIndex(times)
- tempus = _get_era5_tempus(time_df, temporal)
- unique_tempus = tempus.unique()
-
- if "aggregations" in era5.dims:
- aggs_list = era5.coords["aggregations"].to_numpy()
- expected_cols = [
- f"era5_{var}_{t}_{agg}" for var, t, agg in
product(variables, unique_tempus, aggs_list) - ] - else: - expected_cols = [f"era5_{var}_{t}" for var, t in product(variables, unique_tempus)] - - return pd.DataFrame( - index=targets["cell_id"].values, - columns=expected_cols, - dtype=float, - ) - era5_df = era5.to_dataframe() - era5_df["t"] = _get_era5_tempus(era5_df, temporal) - if "aggregations" not in era5.dims: - era5_df = era5_df.pivot_table(index="cell_ids", columns="t") - era5_df.columns = [f"era5_{var}_{t}" for var, t in era5_df.columns] - else: - era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"]) - era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns] + era5 = self.read_member("ERA5-" + era5_agg, cell_ids=cell_ids, lazy=False) + era5_df = _collapse_to_dataframe(era5) + era5_df.columns = [f"era5_{col}" for col in era5_df.columns] # Ensure all target cell_ids are present, fill missing with NaN - era5_df = era5_df.reindex(targets["cell_id"].values, fill_value=np.nan) + era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan) return era5_df - def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: - embeddings = self._read_member("AlphaEarth", targets)["embeddings"] - - if len(embeddings["cell_ids"]) == 0: - # No data for these cells - create empty DataFrame with expected columns - # Use the Dataset metadata to determine column structure - years = embeddings.coords["year"].to_numpy() - aggs = embeddings.coords["agg"].to_numpy() - bands = embeddings.coords["band"].to_numpy() - expected_cols = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in product(years, aggs, bands)] - return pd.DataFrame( - index=targets["cell_id"].values, - columns=expected_cols, - dtype=float, - ) - embeddings_df = embeddings.to_dataframe(name="value") - embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value") - embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns] + # @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"]) + def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame: + embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"] + embeddings_df = _collapse_to_dataframe(embeddings) + embeddings_df.columns = [f"embeddings_{col}" for col in embeddings_df.columns] # Ensure all target cell_ids are present, fill missing with NaN - embeddings_df = embeddings_df.reindex(targets["cell_id"].values, fill_value=np.nan) + embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan) return embeddings_df - def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: - arcticdem = self._read_member("ArcticDEM", targets) - + # @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"]) + def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame: + arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True) if len(arcticdem["cell_ids"]) == 0: # No data for these cells - create empty DataFrame with expected columns # Use the Dataset metadata to determine column structure variables = list(arcticdem.data_vars) aggs = arcticdem.coords["aggregations"].to_numpy() expected_cols = [f"arcticdem_{var}_{agg}" for var, agg in product(variables, aggs)] - return pd.DataFrame( - index=targets["cell_id"].values, - columns=expected_cols, - dtype=float, - ) + return pd.DataFrame(index=cell_ids.to_numpy(), columns=expected_cols, dtype=float) arcticdem_df = 
arcticdem.to_dataframe().pivot_table(index="cell_ids", columns="aggregations") arcticdem_df.columns = [f"arcticdem_{var}_{agg}" for var, agg in arcticdem_df.columns] # Ensure all target cell_ids are present, fill missing with NaN - arcticdem_df = arcticdem_df.reindex(targets["cell_id"].values, fill_value=np.nan) + arcticdem_df = arcticdem_df.reindex(cell_ids.to_numpy(), fill_value=np.nan) return arcticdem_df - def get_stats(self) -> DatasetStats: - """Get dataset statistics. + def create_inference_df( + self, + batch_size: int | None = None, + cache_mode: Literal["none", "overwrite", "read"] = "read", + ) -> Generator[pd.DataFrame]: + """Create an inference feature set generator. - Returns: - DatasetStats: Dictionary containing target stats, member stats, and total features count. + This function creates features for all cell IDs in batches. + If `batch_size` is None or greater than the number of cell IDs, all data is returned in a single batch. + + Args: + batch_size (int | None, optional): The batch size. If None, all data is returned in a single batch. + Defaults to None. + cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation. + "none": No caching. + "read": Read from cache if exists, otherwise create and save to cache. + "overwrite": Always create and save to cache, overwriting existing cache. + Defaults to "none". + + Yields: + Generator[pd.DataFrame]: Generator yielding DataFrames with features for inference. """ - targets = self._read_target() - stats: DatasetStats = { - "target": self.target, - "num_target_samples": len(targets), - "members": {}, - "total_features": 2 if self.add_lonlat else 0, # Lat and Lon - } - - for member in self.members: - ds = self._read_member(member, targets, lazy=True) - n_cols_member = len(ds.data_vars) - for dim in ds.sizes: - if dim != "cell_ids": - n_cols_member *= ds.sizes[dim] - - stats["members"][member] = { - "variables": list(ds.data_vars), - "num_variables": len(ds.data_vars), - "dimensions": dict(ds.sizes), - "coordinates": list(ds.coords), - "num_features": n_cols_member, - } - stats["total_features"] += n_cols_member - - return stats - - def print_stats(self): - stats = self.get_stats() - print(f"=== Target: {stats['target']}") - print(f"\tNumber of target samples: {stats['num_target_samples']}") - - for member, member_stats in stats["members"].items(): - print(f"=== Member: {member}") - print(f"\tVariables ({member_stats['num_variables']}): {member_stats['variables']}") - print(f"\tDimensions: {member_stats['dimensions']}") - print(f"\tCoordinates: {member_stats['coordinates']}") - print(f"\tNumber of features from member: {member_stats['num_features']}") - - print(f"=== Total number of features in dataset: {stats['total_features']}") - - def create( - self, - filter_target_col: str | None = None, - cache_mode: Literal["n", "o", "r"] = "r", - ) -> gpd.GeoDataFrame: - # n: no cache, o: overwrite cache, r: read cache if exists - cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col) - if cache_mode == "r" and cache_file.exists(): - with stopwatch("Loading dataset from cache"): - dataset = gpd.read_parquet(cache_file) - print( - f"Loaded cached dataset from {cache_file} with {len(dataset)} samples" - f" and {len(dataset.columns)} features." - ) - return dataset - with stopwatch("Reading target"): - targets = self._read_target() - if filter_target_col is not None: - targets = targets.loc[targets[filter_target_col]] - print(f"Read and filtered target dataset. 
({len(targets)} samples)") - with stopwatch("Preparing member datasets"): - member_dfs = [] - for member in self.members: - if member.startswith("ERA5"): - era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] - member_dfs.append(self._prep_era5(targets, era5_agg)) - elif member == "AlphaEarth": - member_dfs.append(self._prep_embeddings(targets)) - elif member == "ArcticDEM": - member_dfs.append(self._prep_arcticdem(targets)) - else: - raise NotImplementedError(f"Member {member} not implemented.") - print("Prepared all member datasets. Joining...") - with stopwatch("Joining datasets"): - dataset = targets.set_index("cell_id").join(member_dfs) - print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") - print("Joining complete.") - - if cache_mode in ["o", "r"]: - with stopwatch("Saving dataset to cache"): - dataset.to_parquet(cache_file) - print(f"Saved dataset to cache at {cache_file}.") - return dataset - - def create_batches( - self, - batch_size: int, - filter_target_col: str | None = None, - cache_mode: Literal["n", "o", "r"] = "r", - ) -> Generator[pd.DataFrame]: - targets = self._read_target() - if len(targets) == 0: - raise ValueError("No target samples found.") - elif len(targets) < batch_size: - yield self.create(filter_target_col=filter_target_col, cache_mode=cache_mode) + all_cell_ids = self.cell_ids + if batch_size is None or batch_size >= len(all_cell_ids): + yield self.make_features(cell_ids=all_cell_ids, cache_mode=cache_mode) return - if filter_target_col is not None: - targets = targets.loc[targets[filter_target_col]] + for i in range(0, len(all_cell_ids), batch_size): + batch_cell_ids = all_cell_ids.iloc[i : i + batch_size] + yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode) - for i in range(0, len(targets), batch_size): - # n: no cache, o: overwrite cache, r: read cache if exists - cache_file = entropice.utils.paths.get_dataset_cache( - self.id(), subset=filter_target_col, batch=(i, i + batch_size) - ) - if cache_mode == "r" and cache_file.exists(): - dataset = gpd.read_parquet(cache_file) - print( - f"Loaded cached dataset from {cache_file} with {len(dataset)} samples" - f" and {len(dataset.columns)} features." - ) - yield dataset - else: - targets_batch = targets.iloc[i : i + batch_size] - - member_dfs = [] - for member in self.members: - if member.startswith("ERA5"): - era5_agg: Literal["yearly", "seasonal", "shoulder"] = member.split("-")[1] # ty:ignore[invalid-assignment] - member_dfs.append(self._prep_era5(targets_batch, era5_agg)) - elif member == "AlphaEarth": - member_dfs.append(self._prep_embeddings(targets_batch)) - elif member == "ArcticDEM": - member_dfs.append(self._prep_arcticdem(targets_batch)) - else: - raise NotImplementedError(f"Member {member} not implemented.") - - dataset = targets_batch.set_index("cell_id").join(member_dfs) - print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") - if cache_mode in ["o", "r"]: - dataset.to_parquet(cache_file) - print(f"Saved dataset to cache at {cache_file}.") - yield dataset - - def _cat_and_split( + def create_training_df( self, - dataset: gpd.GeoDataFrame, task: Task, - device: Literal["cpu", "cuda", "torch"], - ) -> CategoricalTrainingDataset: - taskcol = self.taskcol(task) + cache_mode: Literal["none", "overwrite", "read"] = "read", + ) -> pd.DataFrame: + """Create a simple DataFrame for training with e.g. autogluon. 
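# Illustrative sketch (not part of this patch): feeding the frame returned by create_training_df
# to AutoGluon, as the docstring suggests. TabularPredictor(label=...).fit(...) is the standard
# AutoGluon-Tabular entry point; the ensemble settings are arbitrary examples.
from autogluon.tabular import TabularPredictor

ens = DatasetEnsemble(grid="healpix", level=8, target="darts_mllabels")
train_df = ens.create_training_df(task="binary", cache_mode="read")
predictor = TabularPredictor(label="y").fit(train_df)
print(predictor.leaderboard())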
- valid_labels = dataset[taskcol].notna() + This function creates a DataFrame with features and target labels (called "y") for training. + The resulting dataframe contains no geometry information and drops all samples with NaN values. + Does not split into train and test sets, also does not convert to arrays. + For more advanced use cases, use `create_training_set`. - cols_to_drop = {"geometry", taskcol, self.covcol} - cols_to_drop |= { - col - for col in dataset.columns - if col.startswith("dartsml_" if self.target == "darts_mllabels" else "darts_") - } + Args: + task (Task): The task. + cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation. + "none": No caching. + "read": Read from cache if exists, otherwise create and save to cache. + "overwrite": Always create and save to cache, overwriting existing cache. + Defaults to "none". - model_inputs = dataset.drop(columns=cols_to_drop) - # Assert that no column in all-nan - assert not model_inputs.isna().all("index").any(), "Some input columns are all NaN" - # Get valid inputs (rows) - valid_inputs = model_inputs.notna().all("columns") + Returns: + pd.DataFrame: The training DataFrame. - dataset = dataset.loc[valid_labels & valid_inputs] - model_inputs = model_inputs.loc[valid_labels & valid_inputs] - model_labels = dataset[taskcol] + """ + targets = self.get_targets(task) + assert len(targets) > 0, "No target samples found." + features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode) + dataset = targets[["y"]].join(features).dropna() + assert len(dataset) > 0, "No valid samples found after joining features and targets." + return dataset - if task == "binary": - binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category") - elif task == "count": - binned = bin_values(model_labels.astype(int), task=task) - elif task == "density": - binned = bin_values(model_labels, task=task) - else: - raise ValueError("Invalid task.") + def create_training_set( + self, + task: Task, + device: Literal["cpu", "cuda", "torch"] = "cpu", + cache_mode: Literal["none", "overwrite", "read"] = "read", + ) -> TrainingSet: + """Create a full training set for model training. + + This function creates a full training set with features and target labels, + splits it into train and test sets, converts to arrays and moves to the specified device. + + Args: + task (Task): The task. + device (Literal["cpu", "cuda", "torch"], optional): The device to move the data to. Defaults to "cpu". + cache_mode (Literal["none", "read", "overwrite"], optional): Caching mode for feature creation. + "none": No caching. + "read": Read from cache if exists, otherwise create and save to cache. + "overwrite": Always create and save to cache, overwriting existing cache. + Defaults to "none". + + Returns: + TrainingSet: The training set. + Contains a lot of useful information for training and evaluation. + + """ + targets = self.get_targets(task) + assert len(targets) > 0, "No target samples found." + features = self.make_features(cell_ids=targets.index.to_series(), cache_mode=cache_mode).dropna() + assert len(features) > 0, "No valid features found after dropping NaNs." 
+ assert not features.isna().all("index").any(), "Some feature columns are all NaN" # Create train / test split - train_idx, test_idx = train_test_split(dataset.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True) - split = pd.Series(index=dataset.index, dtype=object) + train_idx, test_idx = train_test_split(features.index.to_numpy(), test_size=0.2, random_state=42, shuffle=True) + split = pd.Series(index=features.index, dtype=object) split.loc[train_idx] = "train" split.loc[test_idx] = "test" split = split.astype("category") - X_train = model_inputs.loc[train_idx].to_numpy(dtype="float64") - X_test = model_inputs.loc[test_idx].to_numpy(dtype="float64") - y_train = binned.loc[train_idx].cat.codes.to_numpy(dtype="int64") - y_test = binned.loc[test_idx].cat.codes.to_numpy(dtype="int64") + X_train = features.loc[train_idx].to_numpy(dtype="float64") + X_test = features.loc[test_idx].to_numpy(dtype="float64") + if task in ["binary", "count_regimes", "density_regimes"]: + y_train = targets.loc[train_idx, "y"].cat.codes.to_numpy(dtype="int64") + y_test = targets.loc[test_idx, "y"].cat.codes.to_numpy(dtype="int64") + else: + y_train = targets.loc[train_idx, "y"].to_numpy(dtype="float64") + y_test = targets.loc[test_idx, "y"].to_numpy(dtype="float64") if device == "cuda": X_train = cp.asarray(X_train) @@ -578,28 +642,13 @@ class DatasetEnsemble: y_test = torch.from_numpy(y_test).to(device=torch_device) print(f"Using torch device: {torch.cuda.get_device_name(X_train.device) if X_train.is_cuda else 'cpu'}") else: - assert device == "cpu", "Invalid device specified." + assert device == "cpu", f"Invalid {device=} specified." - return CategoricalTrainingDataset( - dataset=dataset.to_crs("EPSG:4326"), - X=DatasetInputs(data=model_inputs, train=X_train, test=X_test), - y=DatasetLabels(binned=binned, train=y_train, test=y_test, raw_values=model_labels), - z=model_labels, + return TrainingSet( + targets=targets.loc[features.index].to_crs("EPSG:4326"), + features=features, + X=SplittedArrays(train=X_train, test=X_test), + y=SplittedArrays(train=y_train, test=y_test), + z=targets.loc[features.index, "z"], split=split, ) - - def create_cat_training_dataset( - self, task: Task, device: Literal["cpu", "cuda", "torch"] - ) -> CategoricalTrainingDataset: - """Create a categorical dataset for training. - - Args: - task (Task): Task type. - device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto. - - Returns: - CategoricalTrainingDataset: The prepared categorical training dataset. - - """ - dataset = self.create(filter_target_col=self.covcol) - return self._cat_and_split(dataset, task, device) diff --git a/src/entropice/spatial/grids.py b/src/entropice/spatial/grids.py index 3a60d99..5f4a976 100644 --- a/src/entropice/spatial/grids.py +++ b/src/entropice/spatial/grids.py @@ -31,7 +31,7 @@ traceback.install() pretty.install() -def open(grid: Grid, level: int): +def open(grid: Grid, level: int) -> gpd.GeoDataFrame: """Open a saved grid from parquet file. 
Args: @@ -162,7 +162,7 @@ def create_global_hex_grid(resolution): grid = gpd.GeoDataFrame( {"cell_id": hex_id_list, "cell_area": hex_area_list, "geometry": hex_list}, crs="EPSG:4326", - ) + ) # ty:ignore[no-matching-overload] return grid @@ -193,7 +193,7 @@ def create_global_healpix_grid(level: int): geometry = healpix_ds.dggs.cell_boundaries() # Create GeoDataFrame - grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326") + grid = gpd.GeoDataFrame({"cell_id": cell_ids, "geometry": geometry}, crs="EPSG:4326") # ty:ignore[no-matching-overload] grid["cell_area"] = grid.to_crs("EPSG:3413").geometry.area / 1e6 # Convert to km^2 @@ -250,18 +250,18 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure: """ fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": ccrs.NorthPolarStereo()}) - ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree()) + ax.set_extent([-180, 180, 50, 90], crs=ccrs.PlateCarree()) # ty:ignore[unresolved-attribute] # Add features - ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white") - ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey") - ax.add_feature(cfeature.COASTLINE) - ax.add_feature(cfeature.BORDERS, linestyle=":") - ax.add_feature(cfeature.LAKES, alpha=0.5) - ax.add_feature(cfeature.RIVERS) + ax.add_feature(cfeature.LAND, zorder=0, edgecolor="black", facecolor="white") # ty:ignore[unresolved-attribute] + ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="lightgrey") # ty:ignore[unresolved-attribute] + ax.add_feature(cfeature.COASTLINE) # ty:ignore[unresolved-attribute] + ax.add_feature(cfeature.BORDERS, linestyle=":") # ty:ignore[unresolved-attribute] + ax.add_feature(cfeature.LAKES, alpha=0.5) # ty:ignore[unresolved-attribute] + ax.add_feature(cfeature.RIVERS) # ty:ignore[unresolved-attribute] # Add gridlines - gl = ax.gridlines(draw_labels=True) + gl = ax.gridlines(draw_labels=True) # ty:ignore[unresolved-attribute] gl.top_labels = False gl.right_labels = False @@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure: verts = np.vstack([np.sin(theta), np.cos(theta)]).T circle = mpath.Path(verts * radius + center) - ax.set_boundary(circle, transform=ax.transAxes) + ax.set_boundary(circle, transform=ax.transAxes) # ty:ignore[unresolved-attribute] return fig diff --git a/src/entropice/utils/paths.py b/src/entropice/utils/paths.py index 3c688c3..7708845 100644 --- a/src/entropice/utils/paths.py +++ b/src/entropice/utils/paths.py @@ -15,8 +15,9 @@ DATA_DIR = Path("/raid/scratch/tohoel001/data/entropice") # Temporary hardcodin GRIDS_DIR = DATA_DIR / "grids" FIGURES_DIR = Path("figures") -RTS_DIR = DATA_DIR / "darts-rts" -RTS_LABELS_DIR = DATA_DIR / "darts-rts-mllabels" +DARTS_V1_DIR = DATA_DIR / "darts-v1" +DARTS_V2_DIR = DATA_DIR / "darts-v2" +DARTS_MLLABELS_DIR = DATA_DIR / "darts-mllabels" ERA5_DIR = DATA_DIR / "era5" ARCTICDEM_DIR = DATA_DIR / "arcticdem" EMBEDDINGS_DIR = DATA_DIR / "embeddings" @@ -27,7 +28,9 @@ RESULTS_DIR = DATA_DIR / "results" GRIDS_DIR.mkdir(parents=True, exist_ok=True) FIGURES_DIR.mkdir(parents=True, exist_ok=True) -RTS_DIR.mkdir(parents=True, exist_ok=True) +DARTS_V1_DIR.mkdir(parents=True, exist_ok=True) +DARTS_V2_DIR.mkdir(parents=True, exist_ok=True) +DARTS_MLLABELS_DIR.mkdir(parents=True, exist_ok=True) ERA5_DIR.mkdir(parents=True, exist_ok=True) ARCTICDEM_DIR.mkdir(parents=True, exist_ok=True) EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) @@ -39,10 +42,6 @@ 
DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True) watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" -dartsl2_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_features_2018-2023_level2.parquet" -dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.parquet" -darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps" - def _get_gridname(grid: Grid, level: int) -> str: return f"permafrost_{grid}{level}" @@ -60,13 +59,18 @@ def get_grid_viz_file(grid: Grid, level: int) -> Path: return vizfile -def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path: +def get_darts_file(grid: Grid, level: int, version: Literal["v1", "v1-l3", "v2", "v2-l3", "mllabels"]) -> Path: gridname = _get_gridname(grid, level) - if labels: - rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet" - else: - rtsfile = RTS_DIR / f"{gridname}_darts.parquet" - return rtsfile + match version: + case "v1" | "v1-l3": + darts_file = DARTS_V1_DIR / f"{gridname}_darts-{version}.zarr" + case "v2" | "v2-l3": + darts_file = DARTS_V2_DIR / f"{gridname}_darts-{version}.zarr" + case "mllabels": + darts_file = DARTS_MLLABELS_DIR / f"{gridname}_darts-{version}.zarr" + case _: + raise ValueError(f"Unknown DARTS version: {version}") + return darts_file def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path: @@ -85,13 +89,20 @@ def get_era5_stores( agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily", grid: Grid | None = None, level: int | None = None, + temporal: Literal["synopsis"] | None = None, ) -> Path: - if grid is None or level is None: - (ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True) - return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr" + pdir = ERA5_DIR + if temporal is not None: + agg += f"_{temporal}" # ty:ignore[invalid-assignment] + fname = f"{agg}_climate.zarr" - gridname = _get_gridname(grid, level) - aligned_path = ERA5_DIR / f"{gridname}_{agg}_climate.zarr" + if grid is None or level is None: + pdir = pdir / "intermediate" + pdir.mkdir(parents=True, exist_ok=True) + else: + gridname = _get_gridname(grid, level) + fname = f"{gridname}_{fname}" + aligned_path = pdir / fname return aligned_path @@ -122,6 +133,12 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int return cache_file +def get_features_cache(ensemble_id: str, cells_hash: str) -> Path: + cache_dir = DATASET_ENSEMBLES_DIR / "cache" / "features" / ensemble_id + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / f"cells_{cells_hash}.parquet" + + def get_cv_results_dir( name: str, grid: Grid, @@ -133,3 +150,15 @@ def get_cv_results_dir( results_dir = RESULTS_DIR / f"{gridname}_{name}_cv{now}_{task}" results_dir.mkdir(parents=True, exist_ok=True) return results_dir + + +def get_autogluon_results_dir( + grid: Grid, + level: int, + task: Task, +) -> Path: + gridname = _get_gridname(grid, level) + now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + results_dir = RESULTS_DIR / f"{gridname}_autogluon_{now}_{task}" + results_dir.mkdir(parents=True, exist_ok=True) + return results_dir diff --git a/src/entropice/utils/types.py b/src/entropice/utils/types.py index 7c47d84..1c228b1 100644 --- a/src/entropice/utils/types.py +++ b/src/entropice/utils/types.py @@ -15,11 +15,14 @@ type GridLevel = Literal[ "healpix9", "healpix10", ] -type TargetDataset = Literal["darts_rts", "darts_mllabels"] +type TargetDataset = Literal["darts_v2", "darts_v1", 
"darts_mllabels", "darts_rts"] type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"] type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"] -type Task = Literal["binary", "count", "density"] +type Task = Literal["binary", "count_regimes", "density_regimes", "count", "density"] +# TODO: Consider implementing a "timeseries" temporal mode +type TemporalMode = Literal["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024] type Model = Literal["espa", "xgboost", "rf", "knn"] +type Stage = Literal["train", "inference", "visualization"] @dataclass(frozen=True) @@ -70,8 +73,9 @@ class GridConfig: # Note: get_args() doesn't work with Python 3.12+ type statement, so we define explicit lists -all_tasks: list[Task] = ["binary", "count", "density"] -all_target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"] +all_tasks: list[Task] = ["binary", "count_regimes", "density_regimes", "count", "density"] +all_temporal_modes: list[TemporalMode] = ["feature", "synopsis", 2018, 2019, 2020, 2021, 2022, 2023, 2024] +all_target_datasets: list[TargetDataset] = ["darts_mllabels", "darts_rts"] all_l2_source_datasets: list[L2SourceDataset] = [ "ArcticDEM", "ERA5-shoulder",