Add stopwatch to dataset functions
This commit is contained in:
parent
231caa62e7
commit
ad5f810f34
1 changed file with 11 additions and 6 deletions
|
|
@ -253,14 +253,17 @@ class DatasetEnsemble:
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def cell_ids(self) -> pd.Series:
|
def cell_ids(self) -> pd.Series:
|
||||||
|
"""Series of all cell-ids of the grid."""
|
||||||
return self.read_grid()["cell_id"]
|
return self.read_grid()["cell_id"]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def geometries(self) -> pd.Series:
|
def geometries(self) -> pd.Series:
|
||||||
|
"""Series of all geometries of the grid."""
|
||||||
return self.read_grid()["geometry"]
|
return self.read_grid()["geometry"]
|
||||||
|
|
||||||
# @stopwatch("Reading grid")
|
@stopwatch("Reading grid")
|
||||||
def read_grid(self) -> gpd.GeoDataFrame:
|
def read_grid(self) -> gpd.GeoDataFrame:
|
||||||
|
"""Load the grid dataframe and enrich it with lat-lon information."""
|
||||||
grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
|
grid_gdf = entropice.spatial.grids.open(grid=self.grid, level=self.level)
|
||||||
|
|
||||||
# Add the lat / lon of the cell centers
|
# Add the lat / lon of the cell centers
|
||||||
|
|
@ -275,7 +278,7 @@ class DatasetEnsemble:
|
||||||
grid_gdf = grid_gdf.set_index("cell_id")
|
grid_gdf = grid_gdf.set_index("cell_id")
|
||||||
return grid_gdf
|
return grid_gdf
|
||||||
|
|
||||||
# @stopwatch.f("Getting target labels", print_kwargs=["stage"])
|
@stopwatch.f("Getting target labels", print_kwargs=["task"])
|
||||||
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
|
def get_targets(self, task: Task) -> gpd.GeoDataFrame:
|
||||||
"""Create training target labels for a specific task.
|
"""Create training target labels for a specific task.
|
||||||
|
|
||||||
|
|
@ -348,7 +351,7 @@ class DatasetEnsemble:
|
||||||
}
|
}
|
||||||
).set_index("cell_id")
|
).set_index("cell_id")
|
||||||
|
|
||||||
# @stopwatch.f("Reading member", print_kwargs=["member", "stage", "lazy"])
|
@stopwatch.f("Reading member", print_kwargs=["member", "lazy"])
|
||||||
def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901
|
def read_member(self, member: L2SourceDataset, cell_ids: pd.Series | None = None, lazy: bool = False) -> xr.Dataset: # noqa: C901
|
||||||
"""Read a single member (source) of the Ensemble and apply filters based on the ensemble configuration.
|
"""Read a single member (source) of the Ensemble and apply filters based on the ensemble configuration.
|
||||||
|
|
||||||
|
|
@ -480,7 +483,7 @@ class DatasetEnsemble:
|
||||||
dataset.to_parquet(cache_file)
|
dataset.to_parquet(cache_file)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
# @stopwatch.f("Preparing ERA5", print_kwargs=["stage", "temporal"])
|
@stopwatch.f("Preparing ERA5", print_kwargs=["era5_agg"])
|
||||||
def _prep_era5(
|
def _prep_era5(
|
||||||
self,
|
self,
|
||||||
cell_ids: pd.Series,
|
cell_ids: pd.Series,
|
||||||
|
|
@ -493,7 +496,7 @@ class DatasetEnsemble:
|
||||||
era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
|
era5_df = era5_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
|
||||||
return era5_df
|
return era5_df
|
||||||
|
|
||||||
# @stopwatch.f("Preparing ALphaEarth Embeddings", print_kwargs=["stage"])
|
@stopwatch("Preparing AlphaEarth Embeddings")
|
||||||
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
|
def _prep_embeddings(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||||
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
|
embeddings = self.read_member("AlphaEarth", cell_ids=cell_ids, lazy=False)["embeddings"]
|
||||||
embeddings_df = _collapse_to_dataframe(embeddings)
|
embeddings_df = _collapse_to_dataframe(embeddings)
|
||||||
|
|
@ -502,7 +505,7 @@ class DatasetEnsemble:
|
||||||
embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
|
embeddings_df = embeddings_df.reindex(cell_ids.to_numpy(), fill_value=np.nan)
|
||||||
return embeddings_df
|
return embeddings_df
|
||||||
|
|
||||||
# @stopwatch.f("Preparing ArcticDEM", print_kwargs=["stage"])
|
@stopwatch("Preparing ArcticDEM")
|
||||||
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
|
def _prep_arcticdem(self, cell_ids: pd.Series) -> pd.DataFrame:
|
||||||
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
|
arcticdem = self.read_member("ArcticDEM", cell_ids=cell_ids, lazy=True)
|
||||||
if len(arcticdem["cell_ids"]) == 0:
|
if len(arcticdem["cell_ids"]) == 0:
|
||||||
|
|
@ -550,6 +553,7 @@ class DatasetEnsemble:
|
||||||
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
batch_cell_ids = all_cell_ids.iloc[i : i + batch_size]
|
||||||
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
yield self.make_features(cell_ids=batch_cell_ids, cache_mode=cache_mode)
|
||||||
|
|
||||||
|
@stopwatch.f("Creating training DataFrame", print_kwargs=["task", "cache_mode"])
|
||||||
def create_training_df(
|
def create_training_df(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
|
|
@ -581,6 +585,7 @@ class DatasetEnsemble:
|
||||||
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
assert len(dataset) > 0, "No valid samples found after joining features and targets."
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
@stopwatch.f("Creating training Dataset", print_kwargs=["task", "device", "cache_mode"])
|
||||||
def create_training_set(
|
def create_training_set(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue