From ad5f810f3483c665a6c218c519312a9766d297ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 11 Jan 2026 16:00:53 +0100 Subject: [PATCH] Add stopwatch to dataset functions --- src/entropice/ml/dataset.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/entropice/ml/dataset.py b/src/entropice/ml/dataset.py index 1fa3608..275319a 100644 --- a/src/entropice/ml/dataset.py +++ b/src/entropice/ml/dataset.py @@ -253,14 +253,17 @@ class DatasetEnsemble: @cached_property def cell_ids(self) -> pd.Series: + """Series of all cell-ids of the grid.""" return self.read_grid()["cell_id"] @cached_property def geometries(self) -> pd.Series: + """Series of all geometries of the grid.""" return self.read_grid()["geometry"] - # @stopwatch("Reading grid") + @stopwatch("Reading grid") def read_grid(self) -> gpd.GeoDataFrame: + """Load the grid dataframe and enrich it with lat-lon information.""" grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level) # Add the lat / lon of the cell centers @@ -275,7 +278,7 @@ class DatasetEnsemble: grid_gdf = grid_gdf.set_index("cell_id") return grid_gdf - # @stopwatch.f("Getting target labels", print_kwargs=["stage"]) + @stopwatch.f("Getting target labels", print_kwargs=["task"]) def get_targets(self, task: Task) -> gpd.GeoDataFrame: """Create a training target labels for a specific task. @@ -348,7 +351,7 @@ class DatasetEnsemble: } ).set_index("cell_id") - # @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"]) + @stopwatch.f("Reading member", print_kwargs=["member", "lazy"]) def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901 """Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration. @@ -480,7 +483,7 @@ class DatasetEnsemble: dataset.to_parquet(cache_file) return dataset - # @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"]) + @stopwatch.f("Preparing ERA5", print_kwargs=["era5_agg"]) def _prep_era5( self, cell_ids: pd.Series, @@ -493,7 +496,7 @@ class DatasetEnsemble: era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan) return era5_df - # @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"]) + @stopwatch("Preparing AlphaEarth Embeddings") def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame: embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"] embeddings_df = _collapse_to_dataframe(embeddings) @@ -502,7 +505,7 @@ class DatasetEnsemble: embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan) return embeddings_df - # @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"]) + @stopwatch("Preparing ArcticDEM") def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame: arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True) if len(arcticdem["cell_ids"]) == 0: @@ -550,6 +553,7 @@ class DatasetEnsemble: batch_cell_ids = all_cell_ids.iloc[i : i + batch_size] yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode) + @stopwatch.f("Creating training DataFrame", print_kwargs=["task", "cache_mode"]) def create_training_df( self, task: Task, @@ -581,6 +585,7 @@ class DatasetEnsemble: assert len(dataset) > 0, "No valid samples found after joining features and targets." return dataset + @stopwatch.f("Creating training Dataset", print_kwargs=["task", "device", "cache_mode"]) def create_training_set( self, task: Task,