# entropice/tests/validate_datasets.py
#
# 38 lines
# 1.1 KiB
# Python

import cyclopts
from rich import pretty, print, traceback
from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import all_temporal_modes, grid_configs
# Install rich's pretty-printing and traceback handlers for this process.
pretty.install()
traceback.install()
# Command-line application object; commands are registered on it via its decorators.
cli = cyclopts.App()
def _gather_ensemble_stats(e: DatasetEnsemble):
    """Print feature-column counts for a small sample of an ensemble.

    Features are built for the first five entries of ``e.cell_ids`` through
    ``e.make_features``. The report consists of a header naming the grid,
    level and temporal mode, the total number of feature columns, and the
    number of columns whose name starts with each member prefix
    (``arcticdem_``, ``embeddings_``, ``era5_``), followed by a blank line.

    Args:
        e: The dataset ensemble to inspect.
    """
    # Only the column names are inspected below, so a small sample of cells is used.
    features = e.make_features(e.cell_ids[:5])

    banner = f"[bold green]Ensemble Stats for Grid: {e.grid}, Level: {e.level}, Temporal Mode: {e.temporal_mode}[/bold green]"
    print(banner)

    total_columns = len(features.columns)
    print(f"Number of feature columns: {total_columns}")

    for member in ("arcticdem", "embeddings", "era5"):
        prefix = f"{member}_"
        n_member_cols = sum(1 for col in features.columns if col.startswith(prefix))
        print(f" - {member.capitalize()} feature columns: {n_member_cols}")

    # Blank line separates consecutive ensemble reports.
    print()
@cli.default()
def validate_datasets():
    """Print ensemble statistics for every grid configuration and temporal mode."""
    # Lazy generator: each ensemble is constructed right before it is reported,
    # in grid-config-major, temporal-mode-minor order.
    ensembles = (
        DatasetEnsemble(grid=config.grid, level=config.level, temporal_mode=mode)
        for config in grid_configs
        for mode in all_temporal_modes
    )
    for ensemble in ensembles:
        _gather_ensemble_stats(ensemble)
# Run the command-line app only when this file is executed as a script.
if __name__ == "__main__":
    cli()