Fix dataset id
parent 64d23a389d
commit 2be2244cdb

1 changed file with 5 additions and 29 deletions
@@ -12,7 +12,8 @@ Naming conventions:
 """

 import hashlib
-from dataclasses import dataclass, field
+import json
+from dataclasses import asdict, dataclass, field
 from typing import Literal

 import geopandas as gpd
@@ -21,7 +22,6 @@ import seaborn as sns
 import xarray as xr
 from rich import pretty, traceback
 from sklearn import set_config
 from stopuhr import stopwatch

 import entropice.paths

@@ -33,7 +33,6 @@ set_config(array_api_dispatch=True)
 sns.set_theme("talk", "whitegrid")


 @stopwatch.f("Get ERA5 tempus", print_kwargs=["temporal"])
 def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
     if temporal == "yearly":
         return df.index.get_level_values("time").year
@@ -69,27 +68,10 @@ class DatasetEnsemble:

     def id(self):
         return hashlib.blake2b(
-            str(
-                (
-                    self.grid,
-                    self.level,
-                    self.target,
-                    frozenset(self.members),
-                    frozenset(
-                        (member, frozenset(dim_filters.items()))
-                        for member, dim_filters in self.dimension_filters.items()
-                    ),
-                    frozenset(
-                        (member, frozenset(var_filters)) for member, var_filters in self.variable_filters.items()
-                    ),
-                    self.filter_target,
-                    self.add_lonlat,
-                )
-            ).encode("utf-8"),
+            json.dumps(asdict(self), sort_keys=True).encode("utf-8"),
             digest_size=16,
         ).hexdigest()

     @stopwatch.f("Read member data", print_kwargs=["member", "lazy"])
     def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
         if member == "AlphaEarth":
             store = entropice.paths.get_embeddings_store(grid=self.grid, level=self.level)
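For context on the hunk above: the old id() hashed the str() of a hand-built tuple of selected fields, while the new one hashes a JSON dump of the whole dataclass. A likely motivation is that str() of a frozenset of strings depends on set iteration order, which varies between interpreter runs under hash randomization, so the old id was not stable across processes. Below is a minimal, self-contained sketch of the same pattern; the toy class, field names, and values are invented for illustration and are not part of this repository. Note that json.dumps(asdict(self), ...) requires every field to be JSON-serializable.

# Minimal sketch of the id scheme used above (toy class, not the project's code).
import hashlib
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ToyConfig:
    # Hypothetical fields standing in for DatasetEnsemble's configuration.
    grid: str = "toy_grid"
    level: int = 7
    members: list[str] = field(default_factory=lambda: ["AlphaEarth", "ERA5"])

    def id(self) -> str:
        # Serialize all fields deterministically, then hash the bytes.
        payload = json.dumps(asdict(self), sort_keys=True).encode("utf-8")
        return hashlib.blake2b(payload, digest_size=16).hexdigest()


print(ToyConfig().id())  # 32-character hex digest, identical across runs and processes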
@@ -133,7 +115,6 @@ class DatasetEnsemble:
         ds.load()
         return ds

     @stopwatch("Loading targets")
     def _read_target(self) -> gpd.GeoDataFrame:
         if self.target == "darts_rts":
             target_store = entropice.paths.get_darts_rts_file(grid=self.grid, level=self.level)
@@ -157,7 +138,6 @@ class DatasetEnsemble:

         return targets

     @stopwatch.f("Prepare ERA5 data", print_kwargs=["temporal"])
     def _prep_era5(
         self, targets: gpd.GeoDataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]
     ) -> pd.DataFrame:
@@ -172,7 +152,6 @@ class DatasetEnsemble:
         era5_df.columns = [f"era5_{var}_{t}_{agg}" for var, t, agg in era5_df.columns]
         return era5_df

     @stopwatch("Prepare embeddings data")
     def _prep_embeddings(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
         embeddings = self._read_member("AlphaEarth", targets)["embeddings"]
         embeddings_df = embeddings.to_dataframe(name="value")
@@ -180,7 +159,6 @@ class DatasetEnsemble:
         embeddings_df.columns = [f"embeddings_{agg}_{band}_{year}" for year, agg, band in embeddings_df.columns]
         return embeddings_df

     @stopwatch("Prepare arcticdem data")
     def _prep_arcticdem(self, targets: gpd.GeoDataFrame) -> pd.DataFrame:
         arcticdem = self._read_member("ArcticDEM", targets)

@@ -208,12 +186,11 @@ class DatasetEnsemble:
             n_cols += n_cols_member
         print(f"=== Total number of features in dataset: {n_cols}")

     @stopwatch("Create dataset")
     def create(self, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
         # n: no cache, o: overwrite cache, r: read cache if exists
         cache_file = entropice.paths.get_dataset_cache(self.id())
         if cache_mode == "r" and cache_file.exists():
-            dataset = pd.read_parquet(cache_file)
+            dataset = gpd.read_parquet(cache_file)
             print(
                 f"Loaded cached dataset from {cache_file} with {len(dataset)} samples"
                 f" and {len(dataset.columns)} features."
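The other fix in this commit is reading the cached Parquet file with gpd.read_parquet instead of pd.read_parquet, presumably because the cache is written from a GeoDataFrame. A small sketch of the practical difference; the file name and columns here are invented, not taken from the repository:

# Sketch (hypothetical file and columns): round-tripping a GeoDataFrame through Parquet.
import geopandas as gpd
import pandas as pd
from shapely.geometry import Point

gdf = gpd.GeoDataFrame({"cell_id": [1, 2]}, geometry=[Point(0, 0), Point(1, 1)], crs="EPSG:4326")
gdf.to_parquet("cache.parquet")

plain = pd.read_parquet("cache.parquet")  # plain DataFrame; geometry column holds raw WKB bytes
geo = gpd.read_parquet("cache.parquet")   # GeoDataFrame with geometry objects and CRS restored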
@@ -232,7 +209,6 @@ class DatasetEnsemble:
             else:
                 raise NotImplementedError(f"Member {member} not implemented.")

         with stopwatch("Combine datasets"):
             dataset = targets.set_index("cell_id").join(member_dfs)
             print(f"Prepared dataset with {len(dataset)} samples and {len(dataset.columns)} features.")
