From 2be2244cdb8f3fd478420fec85e5711343b73c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 11 Dec 2025 11:40:27 +0100 Subject: [PATCH] Fix dataset id --- src/entropice/dataset.py | 34 +++++----------------------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/src/entropice/dataset.py b/src/entropice/dataset.py index 5c8aab5..dfcd36c 100644 --- a/src/entropice/dataset.py +++ b/src/entropice/dataset.py @@ -12,7 +12,8 @@ Naming conventions: """ import hashlib -from dataclasses import dataclass, field +import json +from dataclasses import asdict, dataclass, field from typing import Literal import geopandas as gpd @@ -21,7 +22,6 @@ import seaborn as sns import xarray as xr from rich import pretty, traceback from sklearn import set_config -from stopuhr import stopwatch import entropice.paths @@ -33,7 +33,6 @@ set_config(array_api_dispatch=True) sns.set_theme("talk", "whitegrid") -@stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"]) def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): if temporal == "yearly": return df.index.get_level_values("time").year @@ -69,27 +68,10 @@ class DatasetEnsemble: def id(self): return hashlib.blake2b( - str( - ( - self.grid, - self.level, - self.target, - frozenset(self.members), - frozenset( - (member, frozenset(dim_filters.items())) - for member, dim_filters in self.dimension_filters.items() - ), - frozenset( - (member, frozenset(var_filters)) for member, var_filters in self.variable_filters.items() - ), - self.filter_target, - self.add_lonlat, - ) - ).encode("utf-8"), + json.dumps(asdict(self), sort_keys=True).encode("utf-8"), digest_size=16, ).hexdigest() - @stopwatch.f("Read member data", print_kwargs=["member", "lazy"]) def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: if member == "AlphaEarth": store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level) @@ -133,7 +115,6 @@ class DatasetEnsemble: ds.load() return ds - @stopwatch("Loading targets") def _read_target(self) -> gpd.GeoDataFrame: if self.target == "darts_rts": target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level) @@ -157,7 +138,6 @@ class DatasetEnsemble: return targets - @stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"]) def _prep_era5( self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"] ) -> pd.DataFrame: @@ -172,7 +152,6 @@ class DatasetEnsemble: era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns] return era5_df - @stopwatch("Prepare embeddings data") def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: embeddings = self._read_member("AlphaEarth", targets)["embeddings"] embeddings_df = embeddings.to_dataframe(name="value") @@ -180,7 +159,6 @@ class DatasetEnsemble: embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns] return embeddings_df - @stopwatch("Prepare arcticdem data") def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame: arcticdem = self._read_member("ArcticDEM", targets) @@ -208,12 +186,11 @@ class DatasetEnsemble: n_cols += n_cols_member print(f"=== Total number of features in dataset: {n_cols}") - @stopwatch("Create dataset") def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: # n: no cache, o: overwrite cache, r: read cache if exists cache_file = entropice.paths.get_dataset_cache(self.id()) if cache_mode == "r" and cache_file.exists(): - dataset = pd.read_parquet(cache_file) + dataset = gpd.read_parquet(cache_file) print( f"Loaded cached dataset from {cache_file} with {len(dataset)} samples" f" and {len(dataset.columns)} features." @@ -232,8 +209,7 @@ class DatasetEnsemble: else: raise NotImplementedError(f"Member {member} not implemented.") - with stopwatch("Combine datasets"): - dataset = targets.set_index("cell_id").join(member_dfs) + dataset = targets.set_index("cell_id").join(member_dfs) print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.") if cache_mode in ["o", "r"]: