Add cache and improve performance of dataset ensembles
This commit is contained in:
parent
67030c9f0d
commit
64d23a389d
2 changed files with 55 additions and 19 deletions
|
|
@ -11,6 +11,7 @@ Naming conventions:
|
|||
- Dimensions of L2 Datasets are e.g. time or aggregation
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
|
|
@ -32,6 +33,7 @@ set_config(array_api_dispatch=True)
|
|||
sns.set_theme("talk", "whitegrid")
|
||||
|
||||
|
||||
@stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"])
|
||||
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
||||
if temporal == "yearly":
|
||||
return df.index.get_level_values("time").year
|
||||
|
|
@ -65,6 +67,29 @@ class DatasetEnsemble:
|
|||
filter_target: str | Literal[False] = False
|
||||
add_lonlat: bool = True
|
||||
|
||||
def id(self) -> str:
    """Return a stable 32-character hex identifier for this ensemble configuration.

    The id is used as the on-disk cache key for the prepared dataset, so it
    must be deterministic ACROSS interpreter runs. String hashing in Python is
    randomized per process (PYTHONHASHSEED), which makes the iteration order —
    and therefore the ``str()`` repr — of a ``frozenset`` of strings
    non-reproducible. All set-like inputs are therefore sorted into tuples
    before hashing.
    """
    # Normalize the per-member filter mappings into sorted, hashable tuples so
    # neither dict insertion order nor set iteration order affects the key.
    dim_filters_key = tuple(
        sorted(
            (member, tuple(sorted(dim_filters.items())))
            for member, dim_filters in self.dimension_filters.items()
        )
    )
    var_filters_key = tuple(
        sorted(
            (member, tuple(sorted(var_filters)))
            for member, var_filters in self.variable_filters.items()
        )
    )
    key = (
        self.grid,
        self.level,
        self.target,
        tuple(sorted(self.members)),
        dim_filters_key,
        var_filters_key,
        self.filter_target,
        self.add_lonlat,
    )
    # blake2b with a 16-byte digest -> 32 hex chars: short enough for a
    # filename, long enough to make collisions between configs implausible.
    return hashlib.blake2b(str(key).encode("utf-8"), digest_size=16).hexdigest()
|
||||
|
||||
@stopwatch.f("Read member data", print_kwargs=["member", "lazy"])
|
||||
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
||||
if member == "AlphaEarth":
|
||||
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
||||
|
|
@ -136,30 +161,20 @@ class DatasetEnsemble:
|
|||
def _prep_era5(
|
||||
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
|
||||
) -> pd.DataFrame:
|
||||
era5_df = []
|
||||
era5 = self._read_member(f"ERA5-{temporal}", targets)
|
||||
|
||||
for var in era5.data_vars:
|
||||
df = era5[var].to_dataframe()
|
||||
df["t"] = _get_era5_tempus(df, temporal)
|
||||
# If aggregations is not in dims, we can pivot directly
|
||||
era5_df = era5.to_dataframe()
|
||||
era5_df["t"] = _get_era5_tempus(era5_df, temporal)
|
||||
if "aggregations" not in era5.dims:
|
||||
df = (
|
||||
df.pivot_table(index="cell_ids", columns="t", values=var)
|
||||
.rename(columns=lambda x: f"era5_{var}_{x}")
|
||||
.rename_axis(None, axis=1)
|
||||
)
|
||||
era5_df = era5_df.pivot_table(index="cell_ids", columns="t")
|
||||
era5_df.columns = [f"era5_{var}_{t}" for var, t in era5_df.columns]
|
||||
else:
|
||||
df = df.pivot_table(index="cell_ids", columns=["t", "aggregations"], values=var)
|
||||
df.columns = [f"era5_{var}_{t}_{agg}" for t, agg in df.columns]
|
||||
era5_df.append(df)
|
||||
era5_df = pd.concat(era5_df, axis=1)
|
||||
era5_df = era5_df.pivot_table(index="cell_ids", columns=["t", "aggregations"])
|
||||
era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
|
||||
return era5_df
|
||||
|
||||
@stopwatch("Prepare embeddings data")
|
||||
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
|
||||
embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
|
||||
|
||||
embeddings_df = embeddings.to_dataframe(name="value")
|
||||
embeddings_df = embeddings_df.pivot_table(index="cell_ids", columns=["year", "agg", "band"], values="value")
|
||||
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
|
||||
|
|
@ -193,7 +208,17 @@ class DatasetEnsemble:
|
|||
n_cols += n_cols_member
|
||||
print(f"=== Total number of features in dataset: {n_cols}")
|
||||
|
||||
def create(self) -> pd.DataFrame:
|
||||
@stopwatch("Create dataset")
|
||||
def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
|
||||
# n: no cache, o: overwrite cache, r: read cache if exists
|
||||
cache_file = entropice.paths.get_dataset_cache(self.id())
|
||||
if cache_mode == "r" and cache_file.exists():
|
||||
dataset = pd.read_parquet(cache_file)
|
||||
print(
|
||||
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
|
||||
f" and {len(dataset.columns)} features."
|
||||
)
|
||||
return dataset
|
||||
targets = self._read_target()
|
||||
|
||||
member_dfs = []
|
||||
|
|
@ -210,4 +235,8 @@ class DatasetEnsemble:
|
|||
with stopwatch("Combine datasets"):
|
||||
dataset = targets.set_index("cell_id").join(member_dfs)
|
||||
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
|
||||
|
||||
if cache_mode in ["o", "r"]:
|
||||
dataset.to_parquet(cache_file)
|
||||
print(f"Saved dataset to cache at {cache_file}.")
|
||||
return dataset
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ ARCTICDEM_DIR = DATA_DIR / "arcticdem"
|
|||
EMBEDDINGS_DIR = DATA_DIR / "embeddings"
|
||||
WATERMASK_DIR = DATA_DIR / "watermask"
|
||||
TRAINING_DIR = DATA_DIR / "training"
|
||||
DATASET_ENSEMBLES_DIR = DATA_DIR / "dataset_ensembles"
|
||||
RESULTS_DIR = DATA_DIR / "results"
|
||||
|
||||
GRIDS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -28,6 +29,7 @@ EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
|||
WATERMASK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
|
||||
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
DATASET_ENSEMBLES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
watermask_file = WATERMASK_DIR / "simplified_water_polygons.shp"
|
||||
|
|
@ -105,6 +107,11 @@ def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
|||
return dataset_file
|
||||
|
||||
|
||||
def get_dataset_cache(eid: str) -> Path:
    """Return the parquet cache file path for the dataset ensemble with id ``eid``."""
    return DATASET_ENSEMBLES_DIR / f"{eid}_dataset.parquet"
|
||||
|
||||
|
||||
def get_cv_results_dir(
|
||||
name: str,
|
||||
grid: Literal["hex", "healpix"],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue