38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
import cyclopts
|
|
from rich import pretty, print, traceback
|
|
|
|
from entropice.ml.dataset import DatasetEnsemble
|
|
from entropice.utils.types import all_temporal_modes, grid_configs
|
|
|
|
# Enable rich's pretty-printing display hook and rich-formatted tracebacks
# for uncaught exceptions before anything else runs.
pretty.install()
traceback.install()

# The cyclopts application; the default command is registered on it below
# via the `@cli.default()` decorator.
cli = cyclopts.App()
|
|
|
|
|
|
def _gather_ensemble_stats(
    e: DatasetEnsemble,
    *,
    sample_size: int = 5,
    members: tuple = ("arcticdem", "embeddings", "era5"),
):
    """Print feature-column statistics for a small sample of an ensemble.

    Builds features for the first ``sample_size`` cell ids of ``e`` and prints
    the ensemble's grid/level/temporal mode, the total number of feature
    columns, the number of columns per member (columns named
    ``"<member>_..."``), and, if any exist, the number of columns that match
    none of the given members. A blank line is printed at the end to separate
    consecutive ensembles.

    Args:
        e: The dataset ensemble to inspect.
        sample_size: Number of cell ids, taken from the start of
            ``e.cell_ids``, to build features for. Only the resulting column
            layout is inspected, so a small sample suffices.
        members: Member names whose ``"<member>_"`` column prefix is counted.
    """
    # Only the column layout is used below, so a handful of cells is enough.
    sample_cell_ids = e.cell_ids[:sample_size]
    features = e.make_features(sample_cell_ids)
    columns = list(features.columns)

    print(
        f"[bold green]Ensemble Stats for Grid: {e.grid}, Level: {e.level}, "
        f"Temporal Mode: {e.temporal_mode}[/bold green]"
    )
    print(f"Number of feature columns: {len(columns)}")

    for member in members:
        n_member_cols = sum(1 for col in columns if col.startswith(f"{member}_"))
        print(f" - {member.capitalize()} feature columns: {n_member_cols}")

    # Surface columns that belong to no listed member; without this the
    # per-member counts can silently fail to add up to the total above.
    member_prefixes = tuple(f"{member}_" for member in members)
    n_other_cols = sum(1 for col in columns if not col.startswith(member_prefixes))
    if n_other_cols:
        print(f" - Other feature columns: {n_other_cols}")

    print()
|
|
|
|
|
|
@cli.default()
def validate_datasets():
    """Print feature-column stats for every grid config and temporal mode.

    One DatasetEnsemble is built per (grid config, temporal mode) pair, and
    its stats are printed before the next pair's ensemble is built.
    """
    # Lazy generator so construction and reporting stay interleaved per pair.
    combos = (
        (grid_cfg, mode)
        for grid_cfg in grid_configs
        for mode in all_temporal_modes
    )
    for grid_cfg, mode in combos:
        ensemble = DatasetEnsemble(
            grid=grid_cfg.grid,
            level=grid_cfg.level,
            temporal_mode=mode,
        )
        _gather_ensemble_stats(ensemble)
|
|
|
|
|
|
if __name__ == "__main__":
    # Dispatch to the cyclopts app only when run as a script, not on import.
    cli()
|