Add stopwatch to dataset functions

This commit is contained in:
Tobias Hölzer 2026-01-11 16:00:53 +01:00
parent 231caa62e7
commit ad5f810f34

View file

@ -253,14 +253,17 @@ class DatasetEnsemble:
@cached_property @cached_property
def cell_ids(self) -> pd.Series: def cell_ids(self) -> pd.Series:
"""Series of all cell-ids of the grid."""
return self.read_grid()["cell_id"] return self.read_grid()["cell_id"]
@cached_property @cached_property
def geometries(self) -> pd.Series: def geometries(self) -> pd.Series:
"""Series of all geometries of the grid."""
return self.read_grid()["geometry"] return self.read_grid()["geometry"]
# @stopwatch("Reading grid") @stopwatch("Reading grid")
def read_grid(self) -> gpd.GeoDataFrame: def read_grid(self) -> gpd.GeoDataFrame:
"""Load the grid dataframe and enrich it with lat-lon information."""
grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level) grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
# Add the lat / lon of the cell centers # Add the lat / lon of the cell centers
@ -275,7 +278,7 @@ class DatasetEnsemble:
grid_gdf = grid_gdf.set_index("cell_id") grid_gdf = grid_gdf.set_index("cell_id")
return grid_gdf return grid_gdf
# @stopwatch.f("Getting target labels", print_kwargs=["stage"]) @stopwatch.f("Getting target labels", print_kwargs=["task"])
def get_targets(self, task: Task) -> gpd.GeoDataFrame: def get_targets(self, task: Task) -> gpd.GeoDataFrame:
"""Create a training target labels for a specific task. """Create a training target labels for a specific task.
@ -348,7 +351,7 @@ class DatasetEnsemble:
} }
).set_index("cell_id") ).set_index("cell_id")
# @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"]) @stopwatch.f("Reading member", print_kwargs=["member", "lazy"])
def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901 def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901
"""Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration. """Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration.
@ -480,7 +483,7 @@ class DatasetEnsemble:
dataset.to_parquet(cache_file) dataset.to_parquet(cache_file)
return dataset return dataset
# @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"]) @stopwatch.f("Preparing ERA5", print_kwargs=["era5_agg"])
def _prep_era5( def _prep_era5(
self, self,
cell_ids: pd.Series, cell_ids: pd.Series,
@ -493,7 +496,7 @@ class DatasetEnsemble:
era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan) era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return era5_df return era5_df
# @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"]) @stopwatch("Preparing AlphaEarth Embeddings")
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame: def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"] embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
embeddings_df = _collapse_to_dataframe(embeddings) embeddings_df = _collapse_to_dataframe(embeddings)
@ -502,7 +505,7 @@ class DatasetEnsemble:
embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan) embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
return embeddings_df return embeddings_df
# @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"]) @stopwatch("Preparing ArcticDEM")
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame: def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True) arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
if len(arcticdem["cell_ids"]) == 0: if len(arcticdem["cell_ids"]) == 0:
@ -550,6 +553,7 @@ class DatasetEnsemble:
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size] batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode) yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
@stopwatch.f("Creating training DataFrame", print_kwargs=["task", "cache_mode"])
def create_training_df( def create_training_df(
self, self,
task: Task, task: Task,
@ -581,6 +585,7 @@ class DatasetEnsemble:
assert len(dataset) > 0, "No valid samples found after joining features and targets." assert len(dataset) > 0, "No valid samples found after joining features and targets."
return dataset return dataset
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "device", "cache_mode"])
def create_training_set( def create_training_set(
self, self,
task: Task, task: Task,