Fix dataset id

This commit is contained in:
Tobias Hölzer 2025-12-11 11:40:27 +01:00
parent 64d23a389d
commit 2be2244cdb

View file

@@ -12,7 +12,8 @@ Naming conventions:
"""
import hashlib
from dataclasses import dataclass, field
import json
from dataclasses import asdict, dataclass, field
from typing import Literal
import geopandas as gpd
@@ -21,7 +22,6 @@ import seaborn as sns
import xarray as xr
from rich import pretty, traceback
from sklearn import set_config
from stopuhr import stopwatch
import entropice.paths
@@ -33,7 +33,6 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid")
@stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"])
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
if temporal == "yearly":
return df.index.get_level_values("time").year
@@ -69,27 +68,10 @@ class DatasetEnsemble:
def id(self):
return hashlib.blake2b(
str(
(
self.grid,
self.level,
self.target,
frozenset(self.members),
frozenset(
(member, frozenset(dim_filters.items()))
for member, dim_filters in self.dimension_filters.items()
),
frozenset(
(member, frozenset(var_filters)) for member, var_filters in self.variable_filters.items()
),
self.filter_target,
self.add_lonlat,
)
).encode("utf-8"),
json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
digest_size=16,
).hexdigest()
@stopwatch.f("Read member data", print_kwargs=["member", "lazy"])
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
if member == "AlphaEarth":
store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
@@ -133,7 +115,6 @@ class DatasetEnsemble:
ds.load()
return ds
@stopwatch("Loading targets")
def _read_target(self) -> gpd.GeoDataFrame:
if self.target == "darts_rts":
target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level)
@@ -157,7 +138,6 @@ class DatasetEnsemble:
return targets
@stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
def _prep_era5(
self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
) -> pd.DataFrame:
@@ -172,7 +152,6 @@ class DatasetEnsemble:
era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
return era5_df
@stopwatch("Prepare embeddings data")
def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
embeddings_df = embeddings.to_dataframe(name="value")
@@ -180,7 +159,6 @@ class DatasetEnsemble:
embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
return embeddings_df
@stopwatch("Prepare arcticdem data")
def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
arcticdem = self._read_member("ArcticDEM", targets)
@@ -208,12 +186,11 @@ class DatasetEnsemble:
n_cols += n_cols_member
print(f"=== Total number of features in dataset: {n_cols}")
@stopwatch("Create dataset")
def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.paths.get_dataset_cache(self.id())
if cache_mode == "r" and cache_file.exists():
dataset = pd.read_parquet(cache_file)
dataset = gpd.read_parquet(cache_file)
print(
f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
f" and {len(dataset.columns)} features."
@@ -232,7 +209,6 @@ class DatasetEnsemble:
else:
raise NotImplementedError(f"Member {member} not implemented.")
with stopwatch("Combine datasets"):
dataset = targets.set_index("cell_id").join(member_dfs)
print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")