Add stopwatch to dataset functions
This commit is contained in:
parent
231caa62e7
commit
ad5f810f34
1 changed files with 11 additions and 6 deletions
|
|
@ -253,14 +253,17 @@ class DatasetEnsemble:
|
|||
|
||||
@cached_property
|
||||
def cell_ids(self) -> pd.Series:
|
||||
"""Series of all cell-ids of the grid."""
|
||||
return self.read_grid()["cell_id"]
|
||||
|
||||
@cached_property
|
||||
def geometries(self) -> pd.Series:
|
||||
"""Series of all geometries of the grid."""
|
||||
return self.read_grid()["geometry"]
|
||||
|
||||
# @stopwatch("Reading grid")
|
||||
@stopwatch("Reading grid")
|
||||
def read_grid(self) -> gpd.GeoDataFrame:
|
||||
"""Load the grid dataframe and enrich it with lat-lon information."""
|
||||
grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
|
||||
|
||||
# Add the lat / lon of the cell centers
|
||||
|
|
@ -275,7 +278,7 @@ class DatasetEnsemble:
|
|||
grid_gdf = grid_gdf.set_index("cell_id")
|
||||
return grid_gdf
|
||||
|
||||
# @stopwatch.f("Getting target labels", print_kwargs=["stage"])
|
||||
@stopwatch.f("Getting target labels", print_kwargs=["task"])
|
||||
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
|
||||
"""Create a training target labels for a specific task.
|
||||
|
||||
|
|
@ -348,7 +351,7 @@ class DatasetEnsemble:
|
|||
}
|
||||
).set_index("cell_id")
|
||||
|
||||
# @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"])
|
||||
@stopwatch.f("Reading member", print_kwargs=["member", "lazy"])
|
||||
def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901
|
||||
"""Read a single member (source) of the Ensemble and applies filters based on the ensemble configuration.
|
||||
|
||||
|
|
@ -480,7 +483,7 @@ class DatasetEnsemble:
|
|||
dataset.to_parquet(cache_file)
|
||||
return dataset
|
||||
|
||||
# @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"])
|
||||
@stopwatch.f("Preparing ERA5", print_kwargs=["era5_agg"])
|
||||
def _prep_era5(
|
||||
self,
|
||||
cell_ids: pd.Series,
|
||||
|
|
@ -493,7 +496,7 @@ class DatasetEnsemble:
|
|||
era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
|
||||
return era5_df
|
||||
|
||||
# @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"])
|
||||
@stopwatch("Preparing AlphaEarth Embeddings")
|
||||
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
|
||||
embeddings_df = _collapse_to_dataframe(embeddings)
|
||||
|
|
@ -502,7 +505,7 @@ class DatasetEnsemble:
|
|||
embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
|
||||
return embeddings_df
|
||||
|
||||
# @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"])
|
||||
@stopwatch("Preparing ArcticDEM")
|
||||
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
|
||||
if len(arcticdem["cell_ids"]) == 0:
|
||||
|
|
@ -550,6 +553,7 @@ class DatasetEnsemble:
|
|||
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
||||
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
||||
|
||||
@stopwatch.f("Creating training DataFrame", print_kwargs=["task", "cache_mode"])
|
||||
def create_training_df(
|
||||
self,
|
||||
task: Task,
|
||||
|
|
@ -581,6 +585,7 @@ class DatasetEnsemble:
|
|||
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
||||
return dataset
|
||||
|
||||
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "device", "cache_mode"])
|
||||
def create_training_set(
|
||||
self,
|
||||
task: Task,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue