Add cache and improve performance of dataset ensembles

This commit is contained in:
Tobias Hölzer 2025-12-10 16:14:38 +01:00
parent 67030c9f0d
commit 64d23a389d
2 changed files with 55 additions and 19 deletions

View file

@ -11,6 +11,7 @@ Naming conventions:
- Dimensions of L2 Datasets are e.g. time or aggregation
"""
import hashlib
from dataclasses import dataclass, field
from typing import Literal
@ -32,6 +33,7 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid")
@stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"])
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
if temporal == "yearly":
return df.index.get_level_values("time").year
@ -65,6 +67,29 @@ class DatasetEnsemble:
filter_target: str | Literal[False] = False
add_lonlat: bool = True
def id(self):
return hashlib.blake2b(
str(
(
self.grid,
self.level,
self.target,
frozenset(self.members),
frozenset(
(member, frozenset(dim_filters.items()))
for member, dim_filters in self.dimension_filters.items()
),
frozenset(
(member, frozenset(var_filters)) for member, var_filters in self.variable_filters.items()
),
self.filter_target,
self.add_lonlat,
)
).encode("utf-8"),
digest_size=16,
).hexdigest()
@stopwatch.f("Read member data", print_kwargs=["member", "lazy"])
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
if member == "AlphaEarth":
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
@ -136,30 +161,20 @@ class DatasetEnsemble:
def _prep_era5(
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
) -> pd.DataFrame:
era5_df = []
era5 = self._read_member(f"ERA5-{temporal}", targets)
for var in era5.data_vars:
df = era5[var].to_dataframe()
df["t"] = _get_era5_tempus(df, temporal)
# If aggregations is not in dims, we can pivot directly
era5_df = era5.to_dataframe()
era5_df["t"] = _get_era5_tempus(era5_df, temporal)
if "aggregations" not in era5.dims:
df = (
df.pivot_table(index="cell_ids", columns="t", values=var)
.rename(columns=lambda x: f"era5_{var}_{x}")
.rename_axis(None, axis=1)
)
era5_df = era5_df.pivot_table(index="cell_ids", columns="t")
era5_df.columns = [f"era5_{var}_{t}" for var, t in era5_df.columns]
else:
df = df.pivot_table(index="cell_ids", columns=["t", "aggregations"], values=var)
df.columns = [f"era5_{var}_{t}_{agg}" for t, agg in df.columns]
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"])
era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
return era5_df
@stopwatch("Prepare embeddings data")
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
@ -193,7 +208,17 @@ class DatasetEnsemble:
n_cols += n_cols_member
print(f"=== Total number of features in dataset: {n_cols}")
def create(self) -> pd.DataFrame:
@stopwatch("Create dataset")
def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.paths.get_dataset_cache(self.id())
if cache_mode == "r" and cache_file.exists():
dataset = pd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
return dataset
targets = self._read_target()
member_dfs = []
@ -210,4 +235,8 @@ class DatasetEnsemble:
with stopwatch("Combine datasets"):
dataset = targets.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
if cache_mode in ["o", "r"]:
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
return dataset

View file

@ -17,6 +17,7 @@ ARCTICDEM_DIR = DATA_DIR / "arcticdem"
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
WATERMASK_DIR = DATA_DIR / "watermask"
TRAINING_DIR = DATA_DIR / "training"
DATASET_ENSEMBLES_DIR = DATA_DIR / "dataset_ensembles"
RESULTS_DIR = DATA_DIR / "results"
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
@ -28,6 +29,7 @@ EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
@ -105,6 +107,11 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file
def get_dataset_cache(eid: str) -> Path:
cache_file = DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
return cache_file
def get_cv_results_dir(
name: str,
grid: Literal["hex", "healpix"],