From 64d23a389dfbda0fc1b19e5561e73ad8310cd973 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Wed, 10 Dec 2025 16:14:38 +0100 Subject: [PATCH] Add cache and improve performance of dataset ensembles --- src/entropice/dataset.py | 67 ++++++++++++++++++++++++++++------------ src/entropice/paths.py | 7 +++++ 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py index 75d695c..5c8aab5 100644 --- a/src/entropice/dataset.py +++ b/src/entropice/dataset.py @@ -11,6 +11,7 @@ Naming conventions: - Dimensions of L2 Datasets are e.g. time or aggregation """ +import hashlib from dataclasses import dataclass, field from typing import Literal @@ -32,6 +33,7 @@ set_config(array_api_dispatch=True) sns.set_theme("talk", "whitegrid") +@stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"]) def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): if temporal == "yearly": return df.index.get_level_values("time").year @@ -65,6 +67,29 @@ class DatasetEnsemble: filter_target: str | Literal[False] = False add_lonlat: bool = True + def id(self): + return hashlib.blake2b( + str( + ( + self.grid, + self.level, + self.target, + frozenset(self.members), + frozenset( + (member, frozenset(dim_filters.items())) + for member, dim_filters in self.dimension_filters.items() + ), + frozenset( + (member, frozenset(var_filters)) for member, var_filters in self.variable_filters.items() + ), + self.filter_target, + self.add_lonlat, + ) + ).encode("utf-8"), + digest_size=16, + ).hexdigest() + + @stopwatch.f("Read member data", print_kwargs=["member", "lazy"]) def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: if member == "AlphaEarth": store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level) @@ -136,30 +161,20 @@ class DatasetEnsemble: def _prep_era5( self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"] ) -> pd.DataFrame: - era5_df = [] era5 = self._read_member(f"ERA5-{temporal}", targets) - - for var in era5.data_vars: - df = era5[var].to_dataframe() - df["t"] = _get_era5_tempus(df, temporal) - # If aggregations is not in dims, we can pivot directly - if "aggregations" not in era5.dims: - df = ( - df.pivot_table(index="cell_ids", columns="t", values=var) - .rename(columns=lambda x: f"era5_{var}_{x}") - .rename_axis(None, axis=1) - ) - else: - df = df.pivot_table(index="cell_ids", columns=["t", "aggregations"], values=var) - df.columns = [f"era5_{var}_{t}_{agg}" for t, agg in df.columns] - era5_df.append(df) - era5_df = pd.concat(era5_df, axis=1) + era5_df = era5.to_dataframe() + era5_df["t"] = _get_era5_tempus(era5_df, temporal) + if "aggregations" not in era5.dims: + era5_df = era5_df.pivot_table(index="cell_ids", columns="t") + era5_df.columns = [f"era5_{var}_{t}" for var, t in era5_df.columns] + else: + era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"]) + era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns] return era5_df @stopwatch("Prepare embeddings data") def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: embeddings = self._read_member("AlphaEarth", targets)["embeddings"] - embeddings_df = embeddings.to_dataframe(name="value") embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value") embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns] @@ -193,7 +208,17 @@ class DatasetEnsemble: n_cols += n_cols_member print(f"=== Total number of features in dataset: {n_cols}") - def create(self) -> pd.DataFrame: + @stopwatch("Create dataset") + def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: + # n: no cache, o: overwrite cache, r: read cache if exists + cache_file = entropice.paths.get_dataset_cache(self.id()) + if cache_mode == "r" and cache_file.exists(): + dataset = pd.read_parquet(cache_file) + print( + f"Loaded cached dataset from {cache_file} with {len(dataset)} samples" + f" and {len(dataset.columns)} features." + ) + return dataset targets = self._read_target() member_dfs = [] @@ -210,4 +235,8 @@ class DatasetEnsemble: with stopwatch("Combine datasets"): dataset = targets.set_index("cell_id").join(member_dfs) print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") + + if cache_mode in ["o", "r"]: + dataset.to_parquet(cache_file) + print(f"Saved dataset to cache at {cache_file}.") return dataset diff --git a/src/entropice/paths.py b/src/entropice/paths.py index 5c5fa76..85dd7db 100644 --- a/src/entropice/paths.py +++ b/src/entropice/paths.py @@ -17,6 +17,7 @@ ARCTICDEM_DIR = DATA_DIR / "arcticdem" EMBEDDINGS_DIR = DATA_DIR / "embeddings" WATERMASK_DIR = DATA_DIR / "watermask" TRAINING_DIR = DATA_DIR / "training" +DATASET_ENSEMBLES_DIR = DATA_DIR / "dataset_ensembles" RESULTS_DIR = DATA_DIR / "results" GRIDS_DIR.mkdir(parents=True, exist_ok=True) @@ -28,6 +29,7 @@ EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) WATERMASK_DIR.mkdir(parents=True, exist_ok=True) TRAINING_DIR.mkdir(parents=True, exist_ok=True) RESULTS_DIR.mkdir(parents=True, exist_ok=True) +DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True) watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" @@ -105,6 +107,11 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path: return dataset_file +def get_dataset_cache(eid: str) -> Path: + cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet" + return cache_file + + def get_cv_results_dir( name: str, grid: Literal["hex", "healpix"],