Fix dataset id

Tobias Hölzer 2025-12-11 11:40:27 +01:00
parent 64d23a389d
commit 2be2244cdb


@@ -12,7 +12,8 @@ Naming conventions:
 """
 import hashlib
-from dataclasses import dataclass, field
+import json
+from dataclasses import asdict, dataclass, field
 from typing import Literal
 
 import geopandas as gpd
@@ -21,7 +22,6 @@ import seaborn as sns
 import xarray as xr
 from rich import pretty, traceback
 from sklearn import set_config
-from stopuhr import stopwatch
 
 import entropice.paths
@ -33,7 +33,6 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid") sns.set_theme("talk", "whitegrid")
@stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"])
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
if temporal == "yearly": if temporal == "yearly":
return df.index.get_level_values("time").year return df.index.get_level_values("time").year
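
Note: besides the id fix, this commit strips the stopuhr timing instrumentation throughout the file. Judging purely from its use here, stopwatch.f(...) acts as a timing decorator and stopwatch(...) as a timing context manager. If lightweight timing is still wanted without the dependency, a hypothetical stand-in could look like the sketch below (this is not the real stopuhr API):

    import time
    from contextlib import contextmanager
    from functools import wraps

    @contextmanager
    def stopwatch(label: str):
        # Times the enclosed block and prints the elapsed seconds.
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"{label}: {time.perf_counter() - start:.2f}s")

    def timed(label: str):
        # Decorator variant, analogous to how stopwatch.f is used in this file.
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                with stopwatch(label):
                    return func(*args, **kwargs)
            return wrapper
        return decorator
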
@@ -69,27 +68,10 @@ class DatasetEnsemble:
     def id(self):
         return hashlib.blake2b(
-            str(
-                (
-                    self.grid,
-                    self.level,
-                    self.target,
-                    frozenset(self.members),
-                    frozenset(
-                        (member, frozenset(dim_filters.items()))
-                        for member, dim_filters in self.dimension_filters.items()
-                    ),
-                    frozenset(
-                        (member, frozenset(var_filters)) for member, var_filters in self.variable_filters.items()
-                    ),
-                    self.filter_target,
-                    self.add_lonlat,
-                )
-            ).encode("utf-8"),
+            json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
             digest_size=16,
         ).hexdigest()
 
-    @stopwatch.f("Read member data", print_kwargs=["member", "lazy"])
     def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
         if member == "AlphaEarth":
             store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
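
Note on the id() change above: str() over a tuple of frozensets depends on set iteration order, which can vary between interpreter runs under hash randomization, so the same ensemble configuration could hash to different cache ids. Serializing the dataclass with json.dumps(asdict(self), sort_keys=True) yields a canonical byte string instead. A minimal sketch of the new scheme, with a reduced field set (the real DatasetEnsemble carries more fields, e.g. target, the filter mappings, filter_target, add_lonlat):

    import hashlib
    import json
    from dataclasses import asdict, dataclass, field

    @dataclass
    class DatasetEnsemble:
        # Reduced field set for illustration only.
        grid: str
        level: int
        members: list[str] = field(default_factory=list)

        def id(self) -> str:
            # sort_keys=True canonicalizes dict key order; every field must
            # be JSON-serializable (scalars, lists, dicts) for this to work.
            return hashlib.blake2b(
                json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
                digest_size=16,
            ).hexdigest()

    print(DatasetEnsemble("h3", 5, ["AlphaEarth"]).id())  # identical across runs

One behavioral difference worth noting: the old code wrapped members in frozenset(), making the id order-insensitive, whereas json.dumps preserves list order, so reordering members now yields a different id.
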
@@ -133,7 +115,6 @@ class DatasetEnsemble:
             ds.load()
         return ds
 
-    @stopwatch("Loading targets")
     def _read_target(self) -> gpd.GeoDataFrame:
         if self.target == "darts_rts":
             target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level)
@@ -157,7 +138,6 @@ class DatasetEnsemble:
         return targets
 
-    @stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
     def _prep_era5(
         self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
     ) -> pd.DataFrame:
@@ -172,7 +152,6 @@ class DatasetEnsemble:
         era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
         return era5_df
 
-    @stopwatch("Prepare embeddings data")
     def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
         embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
         embeddings_df = embeddings.to_dataframe(name="value")
@@ -180,7 +159,6 @@ class DatasetEnsemble:
         embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
         return embeddings_df
 
-    @stopwatch("Prepare arcticdem data")
     def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
         arcticdem = self._read_member("ArcticDEM", targets)
@@ -208,12 +186,11 @@ class DatasetEnsemble:
             n_cols += n_cols_member
         print(f"=== Total number of features in dataset: {n_cols}")
 
-    @stopwatch("Create dataset")
     def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
         # n: no cache, o: overwrite cache, r: read cache if exists
         cache_file = entropice.paths.get_dataset_cache(self.id())
         if cache_mode == "r" and cache_file.exists():
-            dataset = pd.read_parquet(cache_file)
+            dataset = gpd.read_parquet(cache_file)
             print(
                 f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
                 f" and {len(dataset.columns)} features."
@@ -232,8 +209,7 @@ class DatasetEnsemble:
         else:
             raise NotImplementedError(f"Member {member} not implemented.")
 
-        with stopwatch("Combine datasets"):
-            dataset = targets.set_index("cell_id").join(member_dfs)
+        dataset = targets.set_index("cell_id").join(member_dfs)
         print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
 
         if cache_mode in ["o", "r"]: