Cleanup dashboard utils and move pages to views
This commit is contained in:
parent
495ddc13f9
commit
36f8737075
26 changed files with 904 additions and 514 deletions
|
|
@ -11,11 +11,11 @@ Pages:
|
|||
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.inference_page import render_inference_page
|
||||
from entropice.dashboard.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.overview_page import render_overview_page
|
||||
from entropice.dashboard.training_analysis_page import render_training_analysis_page
|
||||
from entropice.dashboard.training_data_page import render_training_data_page
|
||||
from entropice.dashboard.views.inference_page import render_inference_page
|
||||
from entropice.dashboard.views.model_state_page import render_model_state_page
|
||||
from entropice.dashboard.views.overview_page import render_overview_page
|
||||
from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
|
||||
from entropice.dashboard.views.training_data_page import render_training_data_page
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import pydeck as pdk
|
|||
import streamlit as st
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.plots.colors import get_cmap, get_palette
|
||||
from entropice.dashboard.utils.colors import get_cmap, get_palette
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import pydeck as pdk
|
|||
import streamlit as st
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.plots.colors import get_palette
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.data import TrainingResult
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import streamlit as st
|
|||
import xarray as xr
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.plots.colors import get_cmap
|
||||
from entropice.dashboard.utils.colors import get_cmap
|
||||
|
||||
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differantiate from temporal aggregations
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import pydeck as pdk
|
|||
import streamlit as st
|
||||
from shapely.geometry import shape
|
||||
|
||||
from entropice.dashboard.plots.colors import get_palette
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.ml.dataset import CategoricalTrainingDataset
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m
|
|||
cmap = plt.get_cmap("viridis")
|
||||
# Sample colors evenly across the colormap
|
||||
indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends
|
||||
base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices]
|
||||
base_colors = [mcolors.rgb2hex(cmap(idx).tolist()[:3]) for idx in indices]
|
||||
|
||||
# Create matplotlib colormap (for ultraplot and geopandas)
|
||||
matplotlib_cmap = mcolors.ListedColormap(base_colors)
|
||||
96
src/entropice/dashboard/utils/formatters.py
Normal file
96
src/entropice/dashboard/utils/formatters.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Formatters for dashboard display."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from entropice.utils.types import Grid, Model, Task
|
||||
|
||||
|
||||
@dataclass
class ModelDisplayInfo:
    """Human-readable display strings for a model keyword."""

    keyword: Model  # Machine-readable model identifier
    short: str  # Short display name (e.g. "XGBoost")
    long: str  # Spelled-out model name

    def __str__(self) -> str:
        """Render as '<short> (<long>)'."""
        return "{} ({})".format(self.short, self.long)
|
||||
|
||||
|
||||
# Registry of display metadata for every supported model keyword.
model_display_infos: dict[Model, ModelDisplayInfo] = {
    "espa": ModelDisplayInfo(
        keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm"
    ),
    "xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"),
    "rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"),
    "knn": ModelDisplayInfo(keyword="knn", short="k-NN", long="k-Nearest Neighbors Classifier"),
}
|
||||
|
||||
|
||||
@dataclass
class TaskDisplayInfo:
    """Display metadata for a prediction task shown in the dashboard."""

    keyword: Task  # Machine-readable task identifier
    display_name: str  # Human-readable title
    explanation: str  # One-sentence description shown to users
|
||||
|
||||
|
||||
# Registry of display metadata for every supported prediction task.
task_display_infos: dict[Task, TaskDisplayInfo] = {
    "binary": TaskDisplayInfo(
        keyword="binary",
        display_name="Binary Classification",
        explanation="Classify each grid cell as containing or not containing the target feature.",
    ),
    "count": TaskDisplayInfo(
        keyword="count",
        display_name="Count Prediction",
        explanation="Predict the number of target features present within each grid cell.",
    ),
    "density": TaskDisplayInfo(
        keyword="density",
        display_name="Density Estimation",
        explanation="Estimate the density of target features within each grid cell.",
    ),
}
|
||||
|
||||
|
||||
@dataclass
class TrainingResultDisplayInfo:
    """Display metadata identifying a single training run."""

    task: Task
    model: Model
    grid: Grid
    level: int
    timestamp: datetime

    def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str:
        """Build a human-readable run label, either task-first or model-first."""
        task_label = self.task.capitalize()
        model_label = self.model.upper()
        # Grid/level and timestamp always trail the label in the same form.
        suffix = f"{self.grid.capitalize()}-{self.level} ({self.timestamp.strftime('%Y-%m-%d %H:%M')})"
        if format_type == "model_first":
            return f"{model_label} - {task_label} - {suffix}"
        return f"{task_label} - {model_label} - {suffix}"
|
||||
|
||||
|
||||
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').

    """
    # Each underscore-separated token is capitalized; "f1" is special-cased to "F1".
    return " ".join("F1" if token.lower() == "f1" else token.capitalize() for token in metric.split("_"))
|
||||
163
src/entropice/dashboard/utils/loaders.py
Normal file
163
src/entropice/dashboard/utils/loaders.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Data utilities for Entropice dashboard."""
|
||||
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import antimeridian
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import toml
|
||||
import xarray as xr
|
||||
from shapely.geometry import shape
|
||||
|
||||
import entropice.spatial.grids
|
||||
import entropice.utils.paths
|
||||
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
|
||||
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
||||
from entropice.ml.training import TrainingSettings
|
||||
from entropice.utils.types import L2SourceDataset, Task
|
||||
|
||||
|
||||
def _fix_hex_geometry(geom):
    """Fix hexagon geometry crossing the antimeridian.

    Falls back to returning the unmodified geometry (and surfacing a
    Streamlit error) when the antimeridian fix raises a ValueError.
    """
    try:
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        return geom
|
||||
|
||||
|
||||
@dataclass
class TrainingResult:
    """A single hyperparameter-search training run loaded from a result directory.

    Provides the parsed settings and CV results eagerly and the heavier
    artifacts (best model, model state, predictions) via lazy loader methods.
    """

    path: Path  # Result directory this run was loaded from
    settings: TrainingSettings  # Parsed from search_settings.toml
    results: pd.DataFrame  # CV search results (search_results.parquet)
    created_at: float  # Directory st_ctime at load time
    available_metrics: list[str]  # Metric names derived from 'mean_test_*' columns

    @classmethod
    def from_path(cls, result_path: Path) -> "TrainingResult":
        """Load a TrainingResult from a given result directory path.

        Raises:
            FileNotFoundError: If any of the required result files is missing.

        """
        result_file = result_path / "search_results.parquet"
        preds_file = result_path / "predicted_probabilities.parquet"
        settings_file = result_path / "search_settings.toml"
        if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
            raise FileNotFoundError(f"Missing required files in {result_path}")

        # NOTE(review): st_ctime is metadata-change time on Unix and creation
        # time on Windows — confirm this matches the intended "created at".
        created_at = result_path.stat().st_ctime
        settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
        results = pd.read_parquet(result_file)

        # Metric names are encoded as 'mean_test_<metric>' columns in the results table.
        available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]

        return cls(
            path=result_path,
            settings=settings,
            results=results,
            created_at=created_at,
            available_metrics=available_metrics,
        )

    @property
    def display_info(self) -> TrainingResultDisplayInfo:
        """Display metadata (task, model, grid, level, timestamp) for this run."""
        return TrainingResultDisplayInfo(
            task=self.settings.task,
            model=self.settings.model,
            grid=self.settings.grid,
            level=self.settings.level,
            timestamp=datetime.fromtimestamp(self.created_at),
        )

    def load_best_model(self) -> object | None:
        """Load the best model from a training result.

        Returns None if the file is missing; shows a Streamlit error and
        returns None if unpickling fails.
        """
        model_file = self.path / "best_estimator_model.pkl"
        if not model_file.exists():
            return None

        try:
            # SECURITY: pickle.load can execute arbitrary code — only load
            # result files produced by this project, never untrusted input.
            with open(model_file, "rb") as f:
                model = pickle.load(f)
            return model
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None

    def load_model_state(self) -> xr.Dataset | None:
        """Load the model state from a training result.

        Returns None if the file is missing or unreadable (error shown in Streamlit).
        """
        state_file = self.path / "best_estimator_state.nc"
        if not state_file.exists():
            return None

        try:
            state = xr.open_dataset(state_file, engine="h5netcdf")
            return state
        except Exception as e:
            st.error(f"Error loading model state: {e}")
            return None

    def load_predictions(self) -> pd.DataFrame | None:
        """Load predictions from a training result.

        Returns None if the file is missing or unreadable (error shown in Streamlit).
        """
        preds_file = self.path / "predicted_probabilities.parquet"
        if not preds_file.exists():
            return None

        try:
            preds = pd.read_parquet(preds_file)
            return preds
        except Exception as e:
            st.error(f"Error loading predictions: {e}")
            return None
|
||||
|
||||
|
||||
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
    """Load every training result found in the results directory.

    Directories that do not contain the full set of required result files
    (e.g. aborted or still-running searches) are skipped instead of
    failing the whole listing.

    Returns:
        Training results sorted by creation time, most recent first.

    """
    results_dir = entropice.utils.paths.RESULTS_DIR
    training_results: list[TrainingResult] = []
    for result_path in results_dir.iterdir():
        if not result_path.is_dir():
            continue
        try:
            training_results.append(TrainingResult.from_path(result_path))
        except FileNotFoundError:
            # Incomplete result directory (missing parquet/toml files) —
            # skip it so one bad run does not break the entire dashboard.
            continue

    # Sort by creation time (most recent first)
    training_results.sort(key=lambda tr: tr.created_at, reverse=True)
    return training_results
|
||||
|
||||
|
||||
def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]:
    """Load training data for all three tasks.

    Args:
        e: DatasetEnsemble object.

    Returns:
        Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.

    """
    tasks = ("binary", "count", "density")
    return {task: e.create_cat_training_dataset(task, device="cpu") for task in tasks}
|
||||
|
||||
|
||||
def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]:
    """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).

    Args:
        e: DatasetEnsemble object.
        source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.

    Returns:
        Tuple of the xarray.Dataset with the raw member data and the targets GeoDataFrame.

    """
    targets = e._read_target()

    # Eagerly load the member data (lazy=False) so callers can use it directly.
    ds = e._read_member(source, targets, lazy=False)

    return ds, targets
|
||||
482
src/entropice/dashboard/utils/stats.py
Normal file
482
src/entropice/dashboard/utils/stats.py
Normal file
|
|
@ -0,0 +1,482 @@
|
|||
"""General Statistics shared across multiple dashboard pages.
|
||||
|
||||
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal, cast, get_args
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
from stopuhr import stopwatch
|
||||
|
||||
import entropice.spatial.grids
|
||||
import entropice.utils.paths
|
||||
from entropice.dashboard.utils.loaders import TrainingResult
|
||||
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
|
||||
from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MemberStatistics:
    """Statistics for a specific dataset member."""

    feature_count: int  # Number of features from this member
    variable_names: list[str]  # Names of variables from this member
    dimensions: dict[str, int]  # Dimension name to size mapping
    coordinates: list[str]  # Names of coordinates from this member
    size_bytes: int  # Size of this member's data on disk in bytes

    @classmethod
    def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics":
        """Compute metadata-level statistics for one member's store (no full data load).

        Raises:
            NotImplementedError: If the member is not supported.

        """
        # Resolve the on-disk store for the requested member.
        if member == "AlphaEarth":
            store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level)
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            era5_agg = member.split("-")[1]  # e.g. "ERA5-yearly" -> "yearly"
            store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level)  # ty:ignore[invalid-argument-type]
        elif member == "ArcticDEM":
            store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level)
        else:
            raise NotImplementedError(f"Member {member} not implemented.")

        # NOTE(review): if the zarr store is a directory, st_size reports the
        # directory entry size rather than the summed content size — confirm.
        size_bytes = store.stat().st_size
        ds = xr.open_zarr(store, consolidated=False)

        # Delete all coordinates which are not in the dimension
        for coord in ds.coords:
            if coord not in ds.dims:
                ds = ds.drop_vars(coord)
        # Feature count: number of variables multiplied by the size of every
        # dimension other than 'cell_ids'.
        n_cols_member = len(ds.data_vars)
        for dim in ds.sizes:
            if dim != "cell_ids":
                n_cols_member *= ds.sizes[dim]

        return cls(
            feature_count=n_cols_member,
            variable_names=list(ds.data_vars),
            dimensions=dict(ds.sizes),
            coordinates=list(ds.coords),
            size_bytes=size_bytes,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TargetStatistics:
    """Statistics for a specific target dataset."""

    training_cells: dict[Task, int]  # Number of cells used for training
    coverage: dict[Task, float]  # Percentage of total cells covered by training data
    class_counts: dict[Task, dict[str, int]]  # Class name to count mapping per task
    class_distribution: dict[Task, dict[str, float]]  # Class name to percentage mapping per task
    size_bytes: int  # Size of the target dataset on disk in bytes

    @classmethod
    def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics":
        """Compute per-task class statistics for one target dataset.

        Args:
            grid: Grid type to load the target for.
            level: Grid refinement level.
            target: Which target dataset to analyze.
            total_cells: Total number of grid cells (denominator for coverage %).

        Raises:
            NotImplementedError: If the target dataset is not supported.
            ValueError: If an unknown task literal is encountered.

        """
        if target == "darts_rts":
            target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level)
        elif target == "darts_mllabels":
            target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True)
        else:
            raise NotImplementedError(f"Target {target} not implemented.")
        target_gdf = gpd.read_parquet(target_store)
        size_bytes = target_store.stat().st_size

        # The coverage column (and hence the covered subset) depends only on the
        # target, not on the task — compute it once instead of per loop iteration.
        cov_col = covcol[target]
        task_gdf = target_gdf[target_gdf[cov_col]]
        n_covered = len(task_gdf)

        training_cells: dict[Task, int] = {}
        training_coverage: dict[Task, float] = {}
        class_counts: dict[Task, dict[str, int]] = {}
        class_distribution: dict[Task, dict[str, float]] = {}
        tasks = cast(list[Task], get_args(Task))
        for task in tasks:
            task_col = taskcol[task][target]
            training_cells[task] = n_covered
            training_coverage[task] = n_covered / total_cells * 100

            model_labels = task_gdf[task_col].dropna()
            if task == "binary":
                binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
            elif task == "count":
                binned = bin_values(model_labels.astype(int), task=task)
            elif task == "density":
                binned = bin_values(model_labels, task=task)
            else:
                raise ValueError("Invalid task.")
            counts = binned.value_counts()
            distribution = counts / counts.sum() * 100
            class_counts[task] = counts.to_dict()  # ty:ignore[invalid-assignment]
            class_distribution[task] = distribution.to_dict()
        # Use cls(...) for consistency with the other compute() constructors.
        return cls(
            training_cells=training_cells,
            coverage=training_coverage,
            class_counts=class_counts,
            class_distribution=class_distribution,
            size_bytes=size_bytes,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DatasetStatistics:
    """Statistics for a potential dataset at a specific grid and level.

    These statistics are meant to be easily computed without loading a full dataset into memory.
    Further, it is meant to give an overview of all potential dataset compositions for a given grid and level.
    """

    total_features: int  # Total number of features available in the dataset
    total_cells: int  # Total number of grid cells potentially covered
    size_bytes: int  # Size of the dataset on disk in bytes
    members: dict[L2SourceDataset, MemberStatistics]  # Statistics per source dataset member
    target: dict[TargetDataset, TargetStatistics]  # Statistics per target dataset
|
||||
|
||||
|
||||
@st.cache_data
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
    """Compute lightweight dataset statistics for every supported grid/level combination."""
    stats_by_grid_level: dict[GridLevel, DatasetStatistics] = {}
    grid_levels: set[tuple[Grid, int]] = {
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    }
    for grid, level in grid_levels:
        with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
            grid_gdf = entropice.spatial.grids.open(grid, level)  # Ensure grid is registered
            n_cells = len(grid_gdf)
            assert n_cells > 0, "Grid must contain at least one cell."

            # Target statistics first, then member statistics (both cheap, metadata-only).
            target_stats: dict[TargetDataset, TargetStatistics] = {
                target: TargetStatistics.compute(grid=grid, level=level, target=target, total_cells=n_cells)
                for target in cast(list[TargetDataset], get_args(TargetDataset))
            }
            member_stats: dict[L2SourceDataset, MemberStatistics] = {
                member: MemberStatistics.compute(grid=grid, level=level, member=member)
                for member in cast(list[L2SourceDataset], get_args(L2SourceDataset))
            }

            n_features = sum(ms.feature_count for ms in member_stats.values())
            n_bytes = sum(ms.size_bytes for ms in member_stats.values()) + sum(
                ts.size_bytes for ts in target_stats.values()
            )
            key: GridLevel = cast(GridLevel, f"{grid}{level}")
            stats_by_grid_level[key] = DatasetStatistics(
                total_features=n_features,
                total_cells=n_cells,
                size_bytes=n_bytes,
                members=member_stats,
                target=target_stats,
            )

    return stats_by_grid_level
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EnsembleMemberStatistics:
    """Statistics for one source member's columns within a loaded ensemble dataset."""

    n_features: int  # Number of features from this member in the ensemble
    p_features: float  # Percentage of features from this member in the ensemble
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of this member's data in the ensemble in bytes

    @classmethod
    def compute(
        cls,
        dataset: gpd.GeoDataFrame,
        member: L2SourceDataset,
        n_features: int,
    ) -> "EnsembleMemberStatistics":
        """Compute statistics for the columns contributed by one source member.

        Args:
            dataset: Full ensemble dataset containing all members' columns.
            member: Source member whose columns should be analyzed.
            n_features: Total feature count of the ensemble (for the percentage).

        Raises:
            NotImplementedError: If the member is not supported.

        """
        # The masks below are boolean arrays over dataset.columns, so they must
        # be applied with .loc[:, mask]. Plain dataset[mask] would treat the
        # array as a *row* indexer (wrong axis, and a length-mismatch error
        # whenever n_rows != n_cols).
        if member == "AlphaEarth":
            member_dataset = dataset.loc[:, dataset.columns.str.startswith("embeddings_")]
        elif member == "ArcticDEM":
            member_dataset = dataset.loc[:, dataset.columns.str.startswith("arcticdem_")]
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            era5_cols = dataset.columns.str.startswith("era5_")
            # Column-name shape distinguishes the ERA5 aggregations:
            # yearly columns have 2 underscores, seasonal/shoulder have 3.
            cols_with_three_splits = era5_cols & (dataset.columns.str.count("_") == 2)
            cols_with_four_splits = era5_cols & (dataset.columns.str.count("_") == 3)
            cols_with_summer_winter = era5_cols & (dataset.columns.str.contains("_summer|_winter"))
            if member == "ERA5-yearly":
                member_dataset = dataset.loc[:, cols_with_three_splits]
            elif member == "ERA5-seasonal":
                member_dataset = dataset.loc[:, cols_with_summer_winter & cols_with_four_splits]
            elif member == "ERA5-shoulder":
                member_dataset = dataset.loc[:, ~cols_with_summer_winter & cols_with_four_splits]
        else:
            raise NotImplementedError(f"Member {member} not implemented.")
        size_bytes_member = member_dataset.memory_usage(deep=True).sum()
        n_features_member = len(member_dataset.columns)
        p_features_member = n_features_member / n_features
        n_rows_with_nan_member = member_dataset.isna().any(axis=1).sum()
        return EnsembleMemberStatistics(
            n_features=n_features_member,
            p_features=p_features_member,
            n_nanrows=n_rows_with_nan_member,
            size_bytes=size_bytes_member,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EnsembleDatasetStatistics:
    """Statistics for a specified composition / ensemble at a specific grid and level.

    These statistics are meant to be only computed on user demand, thus by loading a dataset into memory.
    That way, the real number of features and cells after NaN filtering can be reported.
    """

    n_features: int  # Number of features in the dataset
    n_cells: int  # Number of grid cells covered
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of the dataset on disk in bytes
    members: dict[L2SourceDataset, EnsembleMemberStatistics]  # Statistics per source dataset member

    @classmethod
    def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics":
        """Compute in-memory statistics for the given ensemble.

        Args:
            ensemble: Ensemble describing the dataset composition.
            dataset: Optional pre-built ensemble dataset; created from the
                ensemble when None.

        """
        # `dataset or ...` would invoke DataFrame.__bool__, which raises
        # "The truth value of a DataFrame is ambiguous" whenever a dataset is
        # actually passed in — compare against None explicitly instead.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)
        # Assert that no column in all-nan
        assert not dataset.isna().all("index").any(), "Some input columns are all NaN"

        size_bytes = dataset.memory_usage(deep=True).sum()
        n_features = len(dataset.columns)
        n_cells = len(dataset)

        # Number of rows which contain any NaN
        n_rows_with_nan = dataset.isna().any(axis=1).sum()

        member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {}
        for member in ensemble.members:
            member_statistics[member] = EnsembleMemberStatistics.compute(
                dataset=dataset,
                member=member,
                n_features=n_features,
            )
        return cls(
            n_features=n_features,
            n_cells=n_cells,
            n_nanrows=n_rows_with_nan,
            size_bytes=size_bytes,
            members=member_statistics,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TrainingDatasetStatistics:
    """Statistics of the categorical (train/test split) training dataset for one task."""

    n_samples: int  # Total number of samples in the dataset
    n_features: int  # Number of features
    feature_names: list[str]  # Names of all features
    n_training_samples: int  # Number of cells used for training
    n_test_samples: int  # Number of cells used for testing
    train_test_ratio: float  # Ratio of training to test samples

    class_labels: list[str]  # Ordered list of class labels
    class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]]  # Min/max raw values per class
    n_classes: int  # Number of classes

    training_class_counts: dict[str, int]  # Class counts in training set
    training_class_distribution: dict[str, float]  # Class percentages in training set
    test_class_counts: dict[str, int]  # Class counts in test set
    test_class_distribution: dict[str, float]  # Class percentages in test set

    raw_value_min: float  # Minimum raw target value
    raw_value_max: float  # Maximum raw target value
    raw_value_mean: float  # Mean raw target value
    raw_value_median: float  # Median raw target value
    raw_value_std: float  # Standard deviation of raw target values

    imbalance_ratio: float  # Smallest class count / largest class count (overall)
    size_bytes: int  # Total memory usage of features in bytes

    @staticmethod
    def _class_stats(y, class_labels: list[str]) -> tuple[dict[str, int], dict[str, float]]:
        """Per-class counts and percentage distribution for one split's encoded labels."""
        counts_by_idx = pd.Series(y).value_counts().sort_index()
        counts = {class_labels[i]: int(counts_by_idx.get(i, 0)) for i in range(len(class_labels))}
        total = sum(counts.values())
        distribution = {k: (v / total * 100) if total > 0 else 0.0 for k, v in counts.items()}
        return counts, distribution

    @classmethod
    def compute(
        cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None
    ) -> "TrainingDatasetStatistics":
        """Compute statistics for the categorical training dataset of one task.

        Args:
            ensemble: Ensemble used to create/split the dataset.
            task: Task to categorize the target for.
            dataset: Optional pre-built ensemble dataset; created when None.

        """
        # `dataset or ...` would invoke DataFrame.__bool__ ("truth value of a
        # DataFrame is ambiguous") whenever a dataset is passed — use an
        # explicit None check instead.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)
        categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")

        # Sample counts
        n_samples = len(categorical_dataset)
        n_training_samples = len(categorical_dataset.y.train)
        n_test_samples = len(categorical_dataset.y.test)
        train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0

        # Feature statistics
        n_features = len(categorical_dataset.X.data.columns)
        feature_names = list(categorical_dataset.X.data.columns)
        size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum()

        # Class information
        class_labels = categorical_dataset.y.labels
        class_intervals = categorical_dataset.y.intervals
        n_classes = len(class_labels)

        # Per-split class counts/distributions (identical logic for train and test)
        training_class_counts, training_class_distribution = cls._class_stats(
            categorical_dataset.y.train, class_labels
        )
        test_class_counts, test_class_distribution = cls._class_stats(categorical_dataset.y.test, class_labels)

        # Raw value statistics
        raw_values = categorical_dataset.y.raw_values
        raw_value_min = float(raw_values.min())
        raw_value_max = float(raw_values.max())
        raw_value_mean = float(raw_values.mean())
        raw_value_median = float(raw_values.median())
        raw_value_std = float(raw_values.std())

        # Imbalance ratio (smallest class / largest class across both splits)
        all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
        nonzero_counts = [c for c in all_counts if c > 0]
        imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0

        return cls(
            n_samples=n_samples,
            n_features=n_features,
            feature_names=feature_names,
            n_training_samples=n_training_samples,
            n_test_samples=n_test_samples,
            train_test_ratio=train_test_ratio,
            class_labels=class_labels,
            class_intervals=class_intervals,
            n_classes=n_classes,
            training_class_counts=training_class_counts,
            training_class_distribution=training_class_distribution,
            test_class_counts=test_class_counts,
            test_class_distribution=test_class_distribution,
            raw_value_min=raw_value_min,
            raw_value_max=raw_value_max,
            raw_value_mean=raw_value_mean,
            raw_value_median=raw_value_median,
            raw_value_std=raw_value_std,
            imbalance_ratio=imbalance_ratio,
            size_bytes=size_bytes,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CVMetricStatistics:
    """Summary statistics of one metric across all CV search candidates."""

    best_score: float
    mean_score: float
    std_score: float
    worst_score: float
    median_score: float
    mean_cv_std: float | None  # Mean of the per-candidate CV std, if recorded

    @classmethod
    def compute(cls, result: TrainingResult, metric: str) -> "CVMetricStatistics":
        """Get cross-validation statistics for a metric."""
        score_col = f"mean_test_{metric}"
        std_col = f"std_test_{metric}"

        if score_col not in result.results.columns:
            raise ValueError(f"Metric {metric} not found in results.")

        scores = result.results[score_col]
        # The std column is optional in older result files.
        mean_cv_std = result.results[std_col].mean() if std_col in result.results.columns else None

        return CVMetricStatistics(
            best_score=scores.max(),
            mean_score=scores.mean(),
            std_score=scores.std(),
            worst_score=scores.min(),
            median_score=scores.median(),
            mean_cv_std=mean_cv_std,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ParameterSpaceSummary:
    """Summary of the searched value range of one hyperparameter."""

    parameter: str
    type: Literal["Numeric", "Categorical"]
    min: float | None
    max: float | None
    mean: float | None
    unique_values: int

    @classmethod
    def compute(cls, result: TrainingResult, param_col: str) -> "ParameterSpaceSummary":
        """Summarize the values explored for one 'param_*' column of the CV results."""
        name = param_col.replace("param_", "")
        values = result.results[param_col].dropna()

        # Non-numeric parameters only get a distinct-value count.
        if not pd.api.types.is_numeric_dtype(values):
            return ParameterSpaceSummary(
                parameter=name,
                type="Categorical",
                min=None,
                max=None,
                mean=None,
                unique_values=len(values.unique()),
            )
        return ParameterSpaceSummary(
            parameter=name,
            type="Numeric",
            min=values.min(),
            max=values.max(),
            mean=values.mean(),
            unique_values=values.nunique(),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CVResultsStatistics:
    """Aggregated CV results: per-metric summaries plus the searched parameter space."""

    metrics: dict[str, CVMetricStatistics]
    parameter_summary: list[ParameterSpaceSummary]

    @classmethod
    def compute(cls, result: TrainingResult) -> "CVResultsStatistics":
        """Compute summary statistics for every available metric and searched parameter."""
        metric_stats = {metric: CVMetricStatistics.compute(result, metric) for metric in result.available_metrics}

        param_cols = [col for col in result.results.columns if col.startswith("param_") and col != "params"]
        summaries = [ParameterSpaceSummary.compute(result, col) for col in param_cols]

        return CVResultsStatistics(metrics=metric_stats, parameter_summary=summaries)

    def metrics_to_dataframe(self) -> pd.DataFrame:
        """Convert metric statistics to a DataFrame."""
        rows = []
        for metric, stats in self.metrics.items():
            rows.append(
                {
                    "Metric": metric,
                    "Best Score": stats.best_score,
                    "Mean Score": stats.mean_score,
                    "Std Dev": stats.std_score,
                    "Worst Score": stats.worst_score,
                    "Median Score": stats.median_score,
                    "Mean CV Std Dev": stats.mean_cv_std,
                }
            )
        return pd.DataFrame(rows)

    def parameters_to_dataframe(self) -> pd.DataFrame:
        """Convert parameter summary to a DataFrame."""
        return pd.DataFrame([asdict(p) for p in self.parameter_summary])
|
||||
|
|
@ -1,232 +0,0 @@
|
|||
"""Training utilities for dashboard."""
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import xarray as xr
|
||||
|
||||
from entropice.dashboard.utils.data import TrainingResult
|
||||
|
||||
|
||||
def format_metric_name(metric: str) -> str:
|
||||
"""Format metric name for display.
|
||||
|
||||
Args:
|
||||
metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').
|
||||
|
||||
Returns:
|
||||
Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
|
||||
|
||||
"""
|
||||
# Split by underscore and capitalize each part
|
||||
parts = metric.split("_")
|
||||
# Special handling for F1
|
||||
formatted_parts = []
|
||||
for part in parts:
|
||||
if part.lower() == "f1":
|
||||
formatted_parts.append("F1")
|
||||
else:
|
||||
formatted_parts.append(part.capitalize())
|
||||
return " ".join(formatted_parts)
|
||||
|
||||
|
||||
def get_available_metrics(results: pd.DataFrame) -> list[str]:
|
||||
"""Get list of available metrics from results.
|
||||
|
||||
Args:
|
||||
results: DataFrame with CV results.
|
||||
|
||||
Returns:
|
||||
List of metric names (without 'mean_test_' prefix).
|
||||
|
||||
"""
|
||||
score_cols = [col for col in results.columns if col.startswith("mean_test_")]
|
||||
return [col.replace("mean_test_", "") for col in score_cols]
|
||||
|
||||
|
||||
def load_best_model(result: TrainingResult):
|
||||
"""Load the best model from a training result.
|
||||
|
||||
Args:
|
||||
result: TrainingResult object.
|
||||
|
||||
Returns:
|
||||
The loaded model object, or None if loading fails.
|
||||
|
||||
"""
|
||||
model_file = result.path / "best_estimator_model.pkl"
|
||||
if not model_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(model_file, "rb") as f:
|
||||
model = pickle.load(f)
|
||||
return model
|
||||
except Exception as e:
|
||||
st.error(f"Error loading model: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_model_state(result: TrainingResult) -> xr.Dataset | None:
|
||||
"""Load the model state from a training result.
|
||||
|
||||
Args:
|
||||
result: TrainingResult object.
|
||||
|
||||
Returns:
|
||||
xarray Dataset with model state, or None if not available.
|
||||
|
||||
"""
|
||||
state_file = result.path / "best_estimator_state.nc"
|
||||
if not state_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
state = xr.open_dataset(state_file, engine="h5netcdf")
|
||||
return state
|
||||
except Exception as e:
|
||||
st.error(f"Error loading model state: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_predictions(result: TrainingResult) -> pd.DataFrame | None:
|
||||
"""Load predictions from a training result.
|
||||
|
||||
Args:
|
||||
result: TrainingResult object.
|
||||
|
||||
Returns:
|
||||
DataFrame with predictions, or None if not available.
|
||||
|
||||
"""
|
||||
preds_file = result.path / "predicted_probabilities.parquet"
|
||||
if not preds_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
preds = pd.read_parquet(preds_file)
|
||||
return preds
|
||||
except Exception as e:
|
||||
st.error(f"Error loading predictions: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_parameter_space_summary(results: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Get summary of parameter space explored.
|
||||
|
||||
Args:
|
||||
results: DataFrame with CV results.
|
||||
|
||||
Returns:
|
||||
DataFrame with parameter ranges and statistics.
|
||||
|
||||
"""
|
||||
param_cols = [col for col in results.columns if col.startswith("param_") and col != "params"]
|
||||
|
||||
summary_data = []
|
||||
for param_col in param_cols:
|
||||
param_name = param_col.replace("param_", "")
|
||||
param_values = results[param_col].dropna()
|
||||
|
||||
if pd.api.types.is_numeric_dtype(param_values):
|
||||
summary_data.append(
|
||||
{
|
||||
"Parameter": param_name,
|
||||
"Type": "Numeric",
|
||||
"Min": f"{param_values.min():.2e}",
|
||||
"Max": f"{param_values.max():.2e}",
|
||||
"Mean": f"{param_values.mean():.2e}",
|
||||
"Unique Values": param_values.nunique(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
unique_vals = param_values.unique()
|
||||
summary_data.append(
|
||||
{
|
||||
"Parameter": param_name,
|
||||
"Type": "Categorical",
|
||||
"Min": "-",
|
||||
"Max": "-",
|
||||
"Mean": "-",
|
||||
"Unique Values": len(unique_vals),
|
||||
}
|
||||
)
|
||||
|
||||
return pd.DataFrame(summary_data)
|
||||
|
||||
|
||||
def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict:
|
||||
"""Get cross-validation statistics for a metric.
|
||||
|
||||
Args:
|
||||
results: DataFrame with CV results.
|
||||
metric: Metric name (without 'mean_test_' prefix).
|
||||
|
||||
Returns:
|
||||
Dictionary with CV statistics.
|
||||
|
||||
"""
|
||||
score_col = f"mean_test_{metric}"
|
||||
std_col = f"std_test_{metric}"
|
||||
|
||||
if score_col not in results.columns:
|
||||
return {}
|
||||
|
||||
stats = {
|
||||
"best_score": results[score_col].max(),
|
||||
"mean_score": results[score_col].mean(),
|
||||
"std_score": results[score_col].std(),
|
||||
"worst_score": results[score_col].min(),
|
||||
"median_score": results[score_col].median(),
|
||||
}
|
||||
|
||||
if std_col in results.columns:
|
||||
stats["mean_cv_std"] = results[std_col].mean()
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame:
|
||||
"""Prepare results dataframe with binned columns for plotting.
|
||||
|
||||
Args:
|
||||
results: DataFrame with CV results.
|
||||
k_bin_width: Width of bins for initial_K parameter.
|
||||
|
||||
Returns:
|
||||
DataFrame with added binned columns.
|
||||
|
||||
"""
|
||||
results_copy = results.copy()
|
||||
|
||||
# Check if we have the parameters
|
||||
if "param_initial_K" in results.columns:
|
||||
# Bin initial_K
|
||||
k_values = results["param_initial_K"].dropna()
|
||||
if len(k_values) > 0:
|
||||
k_min = k_values.min()
|
||||
k_max = k_values.max()
|
||||
k_bins = range(int(k_min), int(k_max) + k_bin_width, k_bin_width)
|
||||
results_copy["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=k_bins, right=False)
|
||||
|
||||
if "param_eps_cl" in results.columns:
|
||||
# Create logarithmic bins for eps_cl
|
||||
eps_cl_values = results["param_eps_cl"].dropna()
|
||||
if len(eps_cl_values) > 0 and eps_cl_values.min() > 0:
|
||||
eps_cl_min = eps_cl_values.min()
|
||||
eps_cl_max = eps_cl_values.max()
|
||||
eps_cl_bins = np.logspace(np.log10(eps_cl_min), np.log10(eps_cl_max), num=10)
|
||||
results_copy["eps_cl_binned"] = pd.cut(results["param_eps_cl"], bins=eps_cl_bins)
|
||||
|
||||
if "param_eps_e" in results.columns:
|
||||
# Create logarithmic bins for eps_e
|
||||
eps_e_values = results["param_eps_e"].dropna()
|
||||
if len(eps_e_values) > 0 and eps_e_values.min() > 0:
|
||||
eps_e_min = eps_e_values.min()
|
||||
eps_e_max = eps_e_values.max()
|
||||
eps_e_bins = np.logspace(np.log10(eps_e_min), np.log10(eps_e_max), num=10)
|
||||
results_copy["eps_e_binned"] = pd.cut(results["param_eps_e"], bins=eps_e_bins)
|
||||
|
||||
return results_copy
|
||||
|
|
@ -1,147 +1,4 @@
|
|||
"""Data utilities for Entropice dashboard."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import antimeridian
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import toml
|
||||
import xarray as xr
|
||||
from shapely.geometry import shape
|
||||
|
||||
import entropice.utils.paths
|
||||
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingResult:
|
||||
"""Simple wrapper of training result data."""
|
||||
|
||||
name: str
|
||||
path: Path
|
||||
settings: dict
|
||||
results: pd.DataFrame
|
||||
created_at: float
|
||||
|
||||
def get_display_name(self, format_type: str = "task_first") -> str:
|
||||
"""Get formatted display name for the training result.
|
||||
|
||||
Args:
|
||||
format_type: Either 'task_first' (for training analysis) or 'model_first' (for model state)
|
||||
|
||||
Returns:
|
||||
Formatted name string
|
||||
|
||||
"""
|
||||
task = self.settings.get("task", "Unknown").capitalize()
|
||||
model = self.settings.get("model", "Unknown").upper()
|
||||
grid = self.settings.get("grid", "Unknown").capitalize()
|
||||
level = self.settings.get("level", "Unknown")
|
||||
timestamp = datetime.fromtimestamp(self.created_at).strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
if format_type == "model_first":
|
||||
return f"{model} - {task} - {grid}-{level} ({timestamp})"
|
||||
else: # task_first
|
||||
return f"{task} - {model} - {grid}-{level} ({timestamp})"
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, result_path: Path) -> "TrainingResult":
|
||||
"""Load a TrainingResult from a given result directory path."""
|
||||
result_file = result_path / "search_results.parquet"
|
||||
preds_file = result_path / "predicted_probabilities.parquet"
|
||||
settings_file = result_path / "search_settings.toml"
|
||||
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
|
||||
raise FileNotFoundError(f"Missing required files in {result_path}")
|
||||
|
||||
created_at = result_path.stat().st_ctime
|
||||
settings = toml.load(settings_file)["settings"]
|
||||
results = pd.read_parquet(result_file)
|
||||
|
||||
# Name should be "task model grid-level (created_at)"
|
||||
model = settings.get("model", "Unknown").upper()
|
||||
name = (
|
||||
f"{settings.get('task', 'Unknown').capitalize()} -"
|
||||
f" {model} -"
|
||||
f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}"
|
||||
f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})"
|
||||
)
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
path=result_path,
|
||||
settings=settings,
|
||||
results=results,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
def _fix_hex_geometry(geom):
|
||||
"""Fix hexagon geometry crossing the antimeridian."""
|
||||
try:
|
||||
return shape(antimeridian.fix_shape(geom))
|
||||
except ValueError as e:
|
||||
st.error(f"Error fixing geometry: {e}")
|
||||
return geom
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_all_training_results() -> list[TrainingResult]:
|
||||
"""Load all training results from the results directory."""
|
||||
results_dir = entropice.utils.paths.RESULTS_DIR
|
||||
training_results: list[TrainingResult] = []
|
||||
for result_path in results_dir.iterdir():
|
||||
if not result_path.is_dir():
|
||||
continue
|
||||
|
||||
try:
|
||||
training_result = TrainingResult.from_path(result_path)
|
||||
training_results.append(training_result)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
# Sort by creation time (most recent first)
|
||||
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
|
||||
return training_results
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingDataset]:
|
||||
"""Load training data for all three tasks.
|
||||
|
||||
Args:
|
||||
e: DatasetEnsemble object.
|
||||
|
||||
Returns:
|
||||
Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
|
||||
|
||||
"""
|
||||
return {
|
||||
"binary": e.create_cat_training_dataset("binary", device="cpu"),
|
||||
"count": e.create_cat_training_dataset("count", device="cpu"),
|
||||
"density": e.create_cat_training_dataset("density", device="cpu"),
|
||||
}
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def load_source_data(e: DatasetEnsemble, source: str):
|
||||
"""Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).
|
||||
|
||||
Args:
|
||||
e: DatasetEnsemble object.
|
||||
source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.
|
||||
|
||||
Returns:
|
||||
xarray.Dataset with the raw data for the specified source.
|
||||
|
||||
"""
|
||||
targets = e._read_target()
|
||||
|
||||
# Load the member data lazily to get metadata
|
||||
ds = e._read_member(source, targets, lazy=False)
|
||||
|
||||
return ds, targets
|
||||
|
||||
|
||||
def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray:
|
||||
|
|
@ -203,7 +60,7 @@ def extract_embedding_features(
|
|||
band=("feature", [f.split("_")[2] for f in embedding_features]),
|
||||
year=("feature", [f.split("_")[3] for f in embedding_features]),
|
||||
)
|
||||
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010
|
||||
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature")
|
||||
return embedding_feature_array
|
||||
|
||||
|
||||
|
|
@ -419,9 +276,9 @@ def extract_era5_features(
|
|||
era5_features_array = era5_features_array.assign_coords(
|
||||
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
|
||||
)
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") # noqa: PD010
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature")
|
||||
else:
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010
|
||||
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature")
|
||||
|
||||
return era5_features_array
|
||||
|
||||
|
|
@ -472,7 +329,7 @@ def extract_arcticdem_features(
|
|||
variable=("feature", [_extract_var_name(f) for f in arcticdem_features]),
|
||||
agg=("feature", [_extract_agg_name(f) for f in arcticdem_features]),
|
||||
)
|
||||
arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature") # noqa: PD010
|
||||
arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature")
|
||||
return arcticdem_feature_array
|
||||
|
||||
|
||||
|
|
@ -503,16 +360,3 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea
|
|||
# Extract the feature importance for common features
|
||||
common_feature_array = importance_array.sel(feature=common_features)
|
||||
return common_feature_array
|
||||
|
||||
|
||||
def get_members_from_settings(settings: dict) -> list[str]:
|
||||
"""Extract the list of dataset members used in training from settings.
|
||||
|
||||
Args:
|
||||
settings: Training settings dictionary.
|
||||
|
||||
Returns:
|
||||
List of member dataset names (e.g., ['AlphaEarth', 'ERA5-yearly', 'ERA5-seasonal']).
|
||||
|
||||
"""
|
||||
return settings.get("members", [])
|
||||
|
|
@ -3,7 +3,6 @@
|
|||
import streamlit as st
|
||||
import xarray as xr
|
||||
|
||||
from entropice.dashboard.plots.colors import generate_unified_colormap
|
||||
from entropice.dashboard.plots.model_state import (
|
||||
plot_arcticdem_heatmap,
|
||||
plot_arcticdem_summary,
|
||||
|
|
@ -17,6 +16,7 @@ from entropice.dashboard.plots.model_state import (
|
|||
plot_era5_time_heatmap,
|
||||
plot_top_features,
|
||||
)
|
||||
from entropice.dashboard.utils.colors import generate_unified_colormap
|
||||
from entropice.dashboard.utils.data import (
|
||||
extract_arcticdem_features,
|
||||
extract_common_features,
|
||||
|
|
@ -8,16 +8,17 @@ import pandas as pd
|
|||
import plotly.express as px
|
||||
import streamlit as st
|
||||
|
||||
from entropice.dashboard.plots.colors import get_palette
|
||||
from entropice.dashboard.utils.data import load_all_training_results
|
||||
from entropice.dashboard.utils.colors import get_palette
|
||||
from entropice.dashboard.utils.loaders import load_all_training_results
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.utils.types import Grid, TargetDataset, Task
|
||||
|
||||
|
||||
# Type definitions for dataset statistics
|
||||
class GridConfig(TypedDict):
|
||||
"""Grid configuration specification with metadata."""
|
||||
|
||||
grid: str
|
||||
grid: Grid
|
||||
level: int
|
||||
grid_name: str
|
||||
grid_sort_key: str
|
||||
|
|
@ -116,7 +117,7 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
|||
Results are cached to avoid redundant computations across different tabs.
|
||||
"""
|
||||
# Define all possible grid configurations
|
||||
grid_configs_raw = [
|
||||
grid_configs_raw: list[tuple[Grid, int]] = [
|
||||
("hex", 3),
|
||||
("hex", 4),
|
||||
("hex", 5),
|
||||
|
|
@ -144,8 +145,8 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache:
|
|||
|
||||
# Compute sample counts
|
||||
sample_counts: list[SampleCountData] = []
|
||||
target_datasets = ["darts_rts", "darts_mllabels"]
|
||||
tasks = ["binary", "count", "density"]
|
||||
target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
|
||||
tasks: list[Task] = ["binary", "count", "density"]
|
||||
|
||||
for grid_config in grid_configs:
|
||||
for target in target_datasets:
|
||||
|
|
@ -810,25 +811,32 @@ def render_experiment_results(training_results):
|
|||
|
||||
def render_overview_page():
|
||||
"""Render the Overview page of the dashboard."""
|
||||
st.title("🏡 Training Results Overview")
|
||||
st.title("🏡 Entropice Overview")
|
||||
st.markdown(
|
||||
"""
|
||||
Welcome to the Entropice Dashboard! This overview page provides a summary of your
|
||||
training experiments and dataset analyses.
|
||||
|
||||
Use the sections below to explore training results, dataset sample counts,
|
||||
and feature configurations.
|
||||
"""
|
||||
)
|
||||
# Load training results
|
||||
training_results = load_all_training_results()
|
||||
|
||||
if not training_results:
|
||||
st.warning("No training results found. Please run some training experiments first.")
|
||||
return
|
||||
else:
|
||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||
|
||||
st.write(f"Found **{len(training_results)}** training result(s)")
|
||||
st.divider()
|
||||
|
||||
st.divider()
|
||||
# Render training results sections
|
||||
render_training_results_summary(training_results)
|
||||
|
||||
# Render training results sections
|
||||
render_training_results_summary(training_results)
|
||||
st.divider()
|
||||
|
||||
st.divider()
|
||||
|
||||
render_experiment_results(training_results)
|
||||
render_experiment_results(training_results)
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ from entropice.dashboard.plots.source_data import (
|
|||
render_era5_plots,
|
||||
)
|
||||
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
|
||||
from entropice.dashboard.utils.data import load_all_training_data, load_source_data
|
||||
from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
|
||||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.spatial import grids
|
||||
|
||||
|
|
@ -60,21 +60,12 @@ def render_training_data_page():
|
|||
# Members selection
|
||||
st.subheader("Dataset Members")
|
||||
|
||||
# Check if AlphaEarth should be disabled
|
||||
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
|
||||
|
||||
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
selected_members = []
|
||||
|
||||
for member in all_members:
|
||||
if member == "AlphaEarth" and disable_alphaearth:
|
||||
# Show disabled checkbox with explanation
|
||||
st.checkbox(
|
||||
member, value=False, disabled=True, help=f"AlphaEarth is not available for {grid} level {level}"
|
||||
)
|
||||
else:
|
||||
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
|
||||
selected_members.append(member)
|
||||
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
|
||||
selected_members.append(member)
|
||||
|
||||
# Form submit button
|
||||
load_button = st.form_submit_button(
|
||||
|
|
@ -6,7 +6,6 @@ Date: October 2025
|
|||
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
|
||||
import cudf
|
||||
import cuml.cluster
|
||||
|
|
@ -25,6 +24,7 @@ from rich.progress import track
|
|||
from entropice.spatial import grids
|
||||
from entropice.utils import codecs
|
||||
from entropice.utils.paths import get_annual_embeddings_file, get_embeddings_store
|
||||
from entropice.utils.types import Grid
|
||||
|
||||
# Filter out the GeoDataFrame.swapaxes deprecation warning
|
||||
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)
|
||||
|
|
@ -55,11 +55,11 @@ def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.D
|
|||
|
||||
|
||||
@cli.command()
|
||||
def download(grid: Literal["hex", "healpix"], level: int):
|
||||
def download(grid: Grid, level: int):
|
||||
"""Extract satellite embeddings from Google Earth Engine and map them to a grid.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
|
||||
"""
|
||||
|
|
@ -126,11 +126,11 @@ def download(grid: Literal["hex", "healpix"], level: int):
|
|||
|
||||
|
||||
@cli.command()
|
||||
def combine_to_zarr(grid: Literal["hex", "healpix"], level: int):
|
||||
def combine_to_zarr(grid: Grid, level: int):
|
||||
"""Combine yearly embeddings parquet files into a single zarr store.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import datetime
|
|||
import multiprocessing as mp
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Literal
|
||||
|
||||
import cupy as cp
|
||||
import cupy_xarray
|
||||
|
|
@ -32,6 +31,7 @@ from entropice.spatial import grids, watermask
|
|||
from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_grid
|
||||
from entropice.utils import codecs
|
||||
from entropice.utils.paths import get_arcticdem_stores
|
||||
from entropice.utils.types import Grid
|
||||
|
||||
traceback.install(show_locals=True, suppress=[cyclopts])
|
||||
pretty.install()
|
||||
|
|
@ -330,7 +330,7 @@ def _open_adem():
|
|||
|
||||
|
||||
@cli.command()
|
||||
def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20):
|
||||
def aggregate(grid: Grid, level: int, concurrent_partitions: int = 20):
|
||||
mp.set_start_method("forkserver", force=True)
|
||||
with (
|
||||
dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ Author: Tobias Hölzer
|
|||
Date: October 2025
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import cyclopts
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
|
|
@ -17,6 +15,7 @@ from stopuhr import stopwatch
|
|||
|
||||
from entropice.spatial import grids
|
||||
from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
|
||||
from entropice.utils.types import Grid
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
|
@ -25,11 +24,11 @@ cli = cyclopts.App(name="darts-rts")
|
|||
|
||||
|
||||
@cli.command()
|
||||
def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
||||
def extract_darts_rts(grid: Grid, level: int):
|
||||
"""Extract RTS labels from DARTS dataset.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
|
||||
"""
|
||||
|
|
@ -91,7 +90,7 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
|
|||
|
||||
|
||||
@cli.command()
|
||||
def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int):
|
||||
def extract_darts_mllabels(grid: Grid, level: int):
|
||||
with stopwatch("Load data"):
|
||||
grid_gdf = grids.open(grid, level)
|
||||
darts_mllabels = (
|
||||
|
|
|
|||
|
|
@ -101,6 +101,7 @@ from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_g
|
|||
from entropice.spatial.xvec import to_xvec
|
||||
from entropice.utils import codecs
|
||||
from entropice.utils.paths import FIGURES_DIR, get_era5_stores
|
||||
from entropice.utils.types import Grid
|
||||
|
||||
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile])
|
||||
pretty.install()
|
||||
|
|
@ -579,7 +580,7 @@ def enrich(n_workers: int = 10, monthly: bool = True, yearly: bool = True, daily
|
|||
|
||||
@cli.command
|
||||
def viz(
|
||||
grid: Literal["hex", "healpix"] | None = None,
|
||||
grid: Grid | None = None,
|
||||
level: int | None = None,
|
||||
agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly",
|
||||
high_qual: bool = False,
|
||||
|
|
@ -587,7 +588,7 @@ def viz(
|
|||
"""Visualize a small overview of ERA5 variables for a given aggregation.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"], optional): Grid type for spatial representation.
|
||||
grid (Grid, optional): Grid type for spatial representation.
|
||||
If provided along with level, the ERA5 data will be decoded onto the specified grid.
|
||||
level (int, optional): Level of the grid for spatial representation.
|
||||
If provided along with grid, the ERA5 data will be decoded onto the specified grid.
|
||||
|
|
@ -653,7 +654,7 @@ def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
|
|||
|
||||
@cli.command
|
||||
def spatial_agg(
|
||||
grid: Literal["hex", "healpix"],
|
||||
grid: Grid,
|
||||
level: int,
|
||||
concurrent_partitions: int = 20,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import hashlib
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from functools import cached_property
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import cupy as cp
|
||||
|
|
@ -32,6 +32,7 @@ from sklearn import set_config
|
|||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import entropice.utils.paths
|
||||
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
|
@ -41,6 +42,27 @@ set_config(array_api_dispatch=True)
|
|||
sns.set_theme("talk", "whitegrid")
|
||||
|
||||
|
||||
covcol: dict[TargetDataset, str] = {
|
||||
"darts_rts": "darts_has_coverage",
|
||||
"darts_mllabels": "dartsml_has_coverage",
|
||||
}
|
||||
|
||||
taskcol: dict[Task, dict[TargetDataset, str]] = {
|
||||
"binary": {
|
||||
"darts_rts": "darts_has_rts",
|
||||
"darts_mllabels": "dartsml_has_rts",
|
||||
},
|
||||
"count": {
|
||||
"darts_rts": "darts_rts_count",
|
||||
"darts_mllabels": "dartsml_rts_count",
|
||||
},
|
||||
"density": {
|
||||
"darts_rts": "darts_rts_density",
|
||||
"darts_mllabels": "dartsml_rts_density",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
|
||||
time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
|
||||
if temporal == "yearly":
|
||||
|
|
@ -53,10 +75,6 @@ def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "
|
|||
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
|
||||
|
||||
|
||||
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
type Task = Literal["binary", "count", "density"]
|
||||
|
||||
|
||||
def bin_values(
|
||||
values: pd.Series,
|
||||
task: Literal["count", "density"],
|
||||
|
|
@ -144,7 +162,7 @@ class DatasetInputs:
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class CategoricalTrainingDataset:
|
||||
dataset: pd.DataFrame
|
||||
dataset: gpd.GeoDataFrame
|
||||
X: DatasetInputs
|
||||
y: DatasetLabels
|
||||
z: pd.Series
|
||||
|
|
@ -162,12 +180,12 @@ class DatasetStats(TypedDict):
|
|||
|
||||
|
||||
@cyclopts.Parameter("*")
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class DatasetEnsemble:
|
||||
grid: Literal["hex", "healpix"]
|
||||
grid: Grid
|
||||
level: int
|
||||
target: Literal["darts_rts", "darts_mllabels"]
|
||||
members: list[L2Dataset] = field(
|
||||
members: list[L2SourceDataset] = field(
|
||||
default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
|
||||
)
|
||||
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
|
||||
|
|
@ -186,19 +204,12 @@ class DatasetEnsemble:
|
|||
|
||||
@property
|
||||
def covcol(self) -> str:
|
||||
return "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
|
||||
return covcol[self.target]
|
||||
|
||||
def taskcol(self, task: Task) -> str:
|
||||
if task == "binary":
|
||||
return "dartsml_has_rts" if self.target == "darts_mllabels" else "darts_has_rts"
|
||||
elif task == "count":
|
||||
return "dartsml_rts_count" if self.target == "darts_mllabels" else "darts_rts_count"
|
||||
elif task == "density":
|
||||
return "dartsml_rts_density" if self.target == "darts_mllabels" else "darts_rts_density"
|
||||
else:
|
||||
raise ValueError(f"Invalid task: {task}")
|
||||
return taskcol[task][self.target]
|
||||
|
||||
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
||||
def _read_member(self, member: L2SourceDataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
|
||||
if member == "AlphaEarth":
|
||||
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
|
||||
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
|
||||
|
|
@ -340,8 +351,9 @@ class DatasetEnsemble:
|
|||
|
||||
print(f"=== Total number of features in dataset: {stats['total_features']}")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame:
|
||||
def create(
|
||||
self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r"
|
||||
) -> gpd.GeoDataFrame:
|
||||
# n: no cache, o: overwrite cache, r: read cache if exists
|
||||
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
|
||||
if cache_mode == "r" and cache_file.exists():
|
||||
|
|
@ -425,26 +437,14 @@ class DatasetEnsemble:
|
|||
print(f"Saved dataset to cache at {cache_file}.")
|
||||
yield dataset
|
||||
|
||||
def create_cat_training_dataset(
|
||||
self, task: Task, device: Literal["cpu", "cuda", "torch"]
|
||||
def _cat_and_split(
|
||||
self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"]
|
||||
) -> CategoricalTrainingDataset:
|
||||
"""Create a categorical dataset for training.
|
||||
|
||||
Args:
|
||||
task (Task): Task type.
|
||||
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
|
||||
|
||||
Returns:
|
||||
CategoricalTrainingDataset: The prepared categorical training dataset.
|
||||
|
||||
"""
|
||||
covcol = "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
|
||||
dataset = self.create(filter_target_col=covcol)
|
||||
taskcol = self.taskcol(task)
|
||||
|
||||
valid_labels = dataset[taskcol].notna()
|
||||
|
||||
cols_to_drop = {"geometry", taskcol, covcol}
|
||||
cols_to_drop = {"geometry", taskcol, self.covcol}
|
||||
cols_to_drop |= {
|
||||
col
|
||||
for col in dataset.columns
|
||||
|
|
@ -505,3 +505,19 @@ class DatasetEnsemble:
|
|||
z=model_labels,
|
||||
split=split,
|
||||
)
|
||||
|
||||
def create_cat_training_dataset(
|
||||
self, task: Task, device: Literal["cpu", "cuda", "torch"]
|
||||
) -> CategoricalTrainingDataset:
|
||||
"""Create a categorical dataset for training.
|
||||
|
||||
Args:
|
||||
task (Task): Task type.
|
||||
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
|
||||
|
||||
Returns:
|
||||
CategoricalTrainingDataset: The prepared categorical training dataset.
|
||||
|
||||
"""
|
||||
dataset = self.create(filter_target_col=self.covcol)
|
||||
return self._cat_and_split(dataset, task, device)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import pickle
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal
|
||||
|
||||
import cupy as cp
|
||||
import cyclopts
|
||||
|
|
@ -23,6 +22,7 @@ from xgboost.sklearn import XGBClassifier
|
|||
from entropice.ml.dataset import DatasetEnsemble
|
||||
from entropice.ml.inference import predict_proba
|
||||
from entropice.utils.paths import get_cv_results_dir
|
||||
from entropice.utils.types import Model, Task
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
|
@ -50,12 +50,20 @@ _metrics = {
|
|||
|
||||
|
||||
@cyclopts.Parameter("*")
|
||||
@dataclass
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class CVSettings:
|
||||
n_iter: int = 2000
|
||||
robust: bool = False
|
||||
task: Literal["binary", "count", "density"] = "binary"
|
||||
model: Literal["espa", "xgboost", "rf", "knn"] = "espa"
|
||||
task: Task = "binary"
|
||||
model: Model = "espa"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class TrainingSettings(DatasetEnsemble, CVSettings):
|
||||
param_grid: dict
|
||||
cv_splits: int
|
||||
metrics: list[str]
|
||||
classes: list[str]
|
||||
|
||||
|
||||
def _create_clf(
|
||||
|
|
@ -141,7 +149,7 @@ def random_cv(
|
|||
"""Perform random cross-validation on the training dataset.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
|
||||
robust (bool, optional): Whether to use robust training. Defaults to False.
|
||||
|
|
@ -202,18 +210,18 @@ def random_cv(
|
|||
# Store the search settings
|
||||
# First convert the param_grid distributions to a serializable format
|
||||
param_grid_serializable = _serialize_param_grid(param_grid)
|
||||
combined_settings = {
|
||||
combined_settings = TrainingSettings(
|
||||
**asdict(settings),
|
||||
**asdict(dataset_ensemble),
|
||||
"param_grid": param_grid_serializable,
|
||||
"cv_splits": cv.get_n_splits(),
|
||||
"metrics": metrics,
|
||||
"classes": training_data.y.labels,
|
||||
}
|
||||
param_grid=param_grid_serializable,
|
||||
cv_splits=cv.get_n_splits(),
|
||||
metrics=metrics,
|
||||
classes=training_data.y.labels,
|
||||
)
|
||||
settings_file = results_dir / "search_settings.toml"
|
||||
print(f"Storing search settings to {settings_file}")
|
||||
with open(settings_file, "w") as f:
|
||||
toml.dump({"settings": combined_settings}, f)
|
||||
toml.dump({"settings": asdict(combined_settings)}, f)
|
||||
|
||||
# Store the best estimator model
|
||||
best_model_file = results_dir / "best_estimator_model.pkl"
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from stopuhr import stopwatch
|
|||
from xdggs.healpix import HealpixInfo
|
||||
|
||||
from entropice.spatial import grids
|
||||
from entropice.utils.types import Grid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -585,7 +586,7 @@ def aggregate_raster_into_grid(
|
|||
raster: xr.Dataset | Callable[[], xr.Dataset],
|
||||
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
|
||||
aggregations: _Aggregations | Literal["interpolate"],
|
||||
grid: Literal["hex", "healpix"],
|
||||
grid: Grid,
|
||||
level: int,
|
||||
n_partitions: int | None = 20,
|
||||
concurrent_partitions: int = 5,
|
||||
|
|
@ -599,7 +600,7 @@ def aggregate_raster_into_grid(
|
|||
If a list of GeoDataFrames is provided, each will be processed as a separate partition.
|
||||
No further partitioning will be done and the n_partitions argument will be ignored.
|
||||
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
|
||||
grid (Literal["hex", "healpix"]): The type of grid to use.
|
||||
grid (Grid): The type of grid to use.
|
||||
level (int): The level of the grid.
|
||||
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
|
||||
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ Date: 09. June 2025
|
|||
"""
|
||||
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import Literal
|
||||
|
||||
import cartopy.crs as ccrs
|
||||
import cartopy.feature as cfeature
|
||||
|
|
@ -26,16 +25,17 @@ from stopuhr import stopwatch
|
|||
from xdggs.healpix import HealpixInfo
|
||||
|
||||
from entropice.utils.paths import get_grid_file, get_grid_viz_file, watermask_file
|
||||
from entropice.utils.types import Grid
|
||||
|
||||
traceback.install()
|
||||
pretty.install()
|
||||
|
||||
|
||||
def open(grid: Literal["hex", "healpix"], level: int):
|
||||
def open(grid: Grid, level: int):
|
||||
"""Open a saved grid from parquet file.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
|
||||
Returns:
|
||||
|
|
@ -47,11 +47,11 @@ def open(grid: Literal["hex", "healpix"], level: int):
|
|||
return grid
|
||||
|
||||
|
||||
def get_cell_ids(grid: Literal["hex", "healpix"], level: int):
|
||||
def get_cell_ids(grid: Grid, level: int):
|
||||
"""Get the cell IDs of a saved grid.
|
||||
|
||||
Args:
|
||||
grid (Literal["hex", "healpix"]): The grid type to use.
|
||||
grid (Grid): The grid type to use.
|
||||
level (int): The grid level to use.
|
||||
|
||||
Returns:
|
||||
|
|
@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
|
|||
return fig
|
||||
|
||||
|
||||
def cli(grid: Literal["hex", "healpix"], level: int):
|
||||
def cli(grid: Grid, level: int):
|
||||
"""CLI entry point."""
|
||||
print(f"Creating {grid} grid at level {level}...")
|
||||
if grid == "hex":
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import os
|
|||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from entropice.utils.types import Grid, Task
|
||||
|
||||
DATA_DIR = (
|
||||
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
|
||||
)
|
||||
|
|
@ -42,23 +44,23 @@ dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.par
|
|||
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
|
||||
|
||||
|
||||
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str:
|
||||
def _get_gridname(grid: Grid, level: int) -> str:
|
||||
return f"permafrost_{grid}{level}"
|
||||
|
||||
|
||||
def get_grid_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||
def get_grid_file(grid: Grid, level: int) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
gridfile = GRIDS_DIR / f"{gridname}_grid.parquet"
|
||||
return gridfile
|
||||
|
||||
|
||||
def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||
def get_grid_viz_file(grid: Grid, level: int) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
vizfile = FIGURES_DIR / f"{gridname}_grid.png"
|
||||
return vizfile
|
||||
|
||||
|
||||
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path:
|
||||
def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
if labels:
|
||||
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
|
||||
|
|
@ -67,13 +69,13 @@ def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool
|
|||
return rtsfile
|
||||
|
||||
|
||||
def get_annual_embeddings_file(grid: Literal["hex", "healpix"], level: int, year: int) -> Path:
|
||||
def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
embfile = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
|
||||
return embfile
|
||||
|
||||
|
||||
def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||
def get_embeddings_store(grid: Grid, level: int) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
|
||||
return embstore
|
||||
|
|
@ -81,9 +83,9 @@ def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path:
|
|||
|
||||
def get_era5_stores(
|
||||
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
|
||||
grid: Literal["hex", "healpix"] | None = None,
|
||||
grid: Grid | None = None,
|
||||
level: int | None = None,
|
||||
):
|
||||
) -> Path:
|
||||
if grid is None or level is None:
|
||||
(ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True)
|
||||
return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr"
|
||||
|
|
@ -94,9 +96,9 @@ def get_era5_stores(
|
|||
|
||||
|
||||
def get_arcticdem_stores(
|
||||
grid: Literal["hex", "healpix"] | None = None,
|
||||
grid: Grid | None = None,
|
||||
level: int | None = None,
|
||||
):
|
||||
) -> Path:
|
||||
if grid is None or level is None:
|
||||
return DATA_DIR / "arcticdem32m.icechunk.zarr"
|
||||
gridname = _get_gridname(grid, level)
|
||||
|
|
@ -104,7 +106,7 @@ def get_arcticdem_stores(
|
|||
return aligned_path
|
||||
|
||||
|
||||
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path:
|
||||
def get_train_dataset_file(grid: Grid, level: int) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
|
||||
return dataset_file
|
||||
|
|
@ -122,9 +124,9 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int
|
|||
|
||||
def get_cv_results_dir(
|
||||
name: str,
|
||||
grid: Literal["hex", "healpix"],
|
||||
grid: Grid,
|
||||
level: int,
|
||||
task: Literal["binary", "count", "density"],
|
||||
task: Task,
|
||||
) -> Path:
|
||||
gridname = _get_gridname(grid, level)
|
||||
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
|
|
|
|||
11
src/entropice/utils/types.py
Normal file
11
src/entropice/utils/types.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Shared types used across the entropice codebase."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
type Grid = Literal["hex", "healpix"]
|
||||
type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"]
|
||||
type TargetDataset = Literal["darts_rts", "darts_mllabels"]
|
||||
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
|
||||
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
|
||||
type Task = Literal["binary", "count", "density"]
|
||||
type Model = Literal["espa", "xgboost", "rf", "knn"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue