Add cache and improve performance of dataset ensembles

This commit is contained in:
Tobias Hölzer 2025-12-10 16:14:38 +01:00
parent 67030c9f0d
commit 64d23a389d
2 changed files with 55 additions and 19 deletions

View file

@@ -11,6 +11,7 @@ Naming conventions:
- Dimensions of L2 Datasets are e.g. time or aggregation
"""
import hashlib
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal from typing import Literal
@@ -32,6 +33,7 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid") sns.set_theme("talk", "whitegrid")
@stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"])
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
if temporal == "yearly": if temporal == "yearly":
return df.index.get_level_values("time").year return df.index.get_level_values("time").year
@@ -65,6 +67,29 @@ class DatasetEnsemble:
filter_target: str | Literal[False] = False filter_target: str | Literal[False] = False
add_lonlat: bool = True add_lonlat: bool = True
def id(self) -> str:
    """Return a stable 32-character hex identifier for this ensemble configuration.

    The id is used as the cache key for the prepared dataset (see
    ``get_dataset_cache``), so it must be deterministic across Python
    processes. Collections are therefore sorted into tuples before being
    stringified: the ``str()`` of a ``frozenset`` follows per-process hash
    randomization (PYTHONHASHSEED) and would yield a different id — and
    thus a cache miss — on every run.

    Returns:
        Hex digest (blake2b, 16-byte digest) of the configuration fingerprint.
    """
    fingerprint = (
        self.grid,
        self.level,
        self.target,
        tuple(sorted(self.members)),
        tuple(
            (member, tuple(sorted(dim_filters.items())))
            for member, dim_filters in sorted(self.dimension_filters.items())
        ),
        tuple(
            (member, tuple(sorted(var_filters)))
            for member, var_filters in sorted(self.variable_filters.items())
        ),
        self.filter_target,
        self.add_lonlat,
    )
    return hashlib.blake2b(str(fingerprint).encode("utf-8"), digest_size=16).hexdigest()
@stopwatch.f("Read member data", print_kwargs=["member", "lazy"])
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
if member == "AlphaEarth": if member == "AlphaEarth":
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level) store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
@@ -136,30 +161,20 @@ class DatasetEnsemble:
def _prep_era5( def _prep_era5(
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"] self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
) -> pd.DataFrame: ) -> pd.DataFrame:
era5_df = []
era5 = self._read_member(f"ERA5-{temporal}", targets) era5 = self._read_member(f"ERA5-{temporal}", targets)
era5_df = era5.to_dataframe()
for var in era5.data_vars: era5_df["t"] = _get_era5_tempus(era5_df, temporal)
df = era5[var].to_dataframe() if "aggregations" not in era5.dims:
df["t"] = _get_era5_tempus(df, temporal) era5_df = era5_df.pivot_table(index="cell_ids", columns="t")
# If aggregations is not in dims, we can pivot directly era5_df.columns = [f"era5_{var}_{t}" for var, t in era5_df.columns]
if "aggregations" not in era5.dims: else:
df = ( era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"])
df.pivot_table(index="cell_ids", columns="t", values=var) era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
.rename(columns=lambda x: f"era5_{var}_{x}")
.rename_axis(None, axis=1)
)
else:
df = df.pivot_table(index="cell_ids", columns=["t", "aggregations"], values=var)
df.columns = [f"era5_{var}_{t}_{agg}" for t, agg in df.columns]
era5_df.append(df)
era5_df = pd.concat(era5_df, axis=1)
return era5_df return era5_df
@stopwatch("Prepare embeddings data") @stopwatch("Prepare embeddings data")
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
embeddings = self._read_member("AlphaEarth", targets)["embeddings"] embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
embeddings_df = embeddings.to_dataframe(name="value") embeddings_df = embeddings.to_dataframe(name="value")
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value") embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns] embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
@@ -193,7 +208,17 @@ class DatasetEnsemble:
n_cols += n_cols_member n_cols += n_cols_member
print(f"=== Total number of features in dataset: {n_cols}") print(f"=== Total number of features in dataset: {n_cols}")
def create(self) -> pd.DataFrame: @stopwatch("Create dataset")
def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.paths.get_dataset_cache(self.id())
if cache_mode == "r" and cache_file.exists():
dataset = pd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
)
return dataset
targets = self._read_target() targets = self._read_target()
member_dfs = [] member_dfs = []
@@ -210,4 +235,8 @@ class DatasetEnsemble:
with stopwatch("Combine datasets"): with stopwatch("Combine datasets"):
dataset = targets.set_index("cell_id").join(member_dfs) dataset = targets.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
if cache_mode in ["o", "r"]:
dataset.to_parquet(cache_file)
print(f"Saved dataset to cache at {cache_file}.")
return dataset return dataset

View file

@@ -17,6 +17,7 @@ ARCTICDEM_DIR = DATA_DIR / "arcticdem"
EMBEDDINGS_DIR = DATA_DIR / "embeddings" EMBEDDINGS_DIR = DATA_DIR / "embeddings"
WATERMASK_DIR = DATA_DIR / "watermask" WATERMASK_DIR = DATA_DIR / "watermask"
TRAINING_DIR = DATA_DIR / "training" TRAINING_DIR = DATA_DIR / "training"
DATASET_ENSEMBLES_DIR = DATA_DIR / "dataset_ensembles"
RESULTS_DIR = DATA_DIR / "results" RESULTS_DIR = DATA_DIR / "results"
GRIDS_DIR.mkdir(parents=True, exist_ok=True) GRIDS_DIR.mkdir(parents=True, exist_ok=True)
@@ -28,6 +29,7 @@ EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
WATERMASK_DIR.mkdir(parents=True, exist_ok=True) WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
TRAINING_DIR.mkdir(parents=True, exist_ok=True) TRAINING_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True) RESULTS_DIR.mkdir(parents=True, exist_ok=True)
DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp" watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
@@ -105,6 +107,11 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
return dataset_file return dataset_file
def get_dataset_cache(eid: str) -> Path:
    """Return the parquet cache-file path for the dataset ensemble with id *eid*."""
    return DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
def get_cv_results_dir( def get_cv_results_dir(
name: str, name: str,
grid: Literal["hex", "healpix"], grid: Literal["hex", "healpix"],