Clean up dashboard utils and move pages to views

This commit is contained in:
Tobias Hölzer 2026-01-04 02:12:53 +01:00
parent 495ddc13f9
commit 36f8737075
26 changed files with 904 additions and 514 deletions

View file

@ -11,11 +11,11 @@ Pages:
import streamlit as st import streamlit as st
from entropice.dashboard.inference_page import render_inference_page from entropice.dashboard.views.inference_page import render_inference_page
from entropice.dashboard.model_state_page import render_model_state_page from entropice.dashboard.views.model_state_page import render_model_state_page
from entropice.dashboard.overview_page import render_overview_page from entropice.dashboard.views.overview_page import render_overview_page
from entropice.dashboard.training_analysis_page import render_training_analysis_page from entropice.dashboard.views.training_analysis_page import render_training_analysis_page
from entropice.dashboard.training_data_page import render_training_data_page from entropice.dashboard.views.training_data_page import render_training_data_page
def main(): def main():

View file

@ -12,7 +12,7 @@ import pydeck as pdk
import streamlit as st import streamlit as st
from shapely.geometry import shape from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_cmap, get_palette from entropice.dashboard.utils.colors import get_cmap, get_palette
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble

View file

@ -7,7 +7,7 @@ import pydeck as pdk
import streamlit as st import streamlit as st
from shapely.geometry import shape from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.data import TrainingResult from entropice.dashboard.utils.data import TrainingResult

View file

@ -10,7 +10,7 @@ import streamlit as st
import xarray as xr import xarray as xr
from shapely.geometry import shape from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_cmap from entropice.dashboard.utils.colors import get_cmap
# TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations # TODO: Rename "Aggregation" to "Pixel-to-cell Aggregation" to differentiate from temporal aggregations

View file

@ -7,7 +7,7 @@ import pydeck as pdk
import streamlit as st import streamlit as st
from shapely.geometry import shape from shapely.geometry import shape
from entropice.dashboard.plots.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.ml.dataset import CategoricalTrainingDataset from entropice.ml.dataset import CategoricalTrainingDataset

View file

@ -132,7 +132,7 @@ def generate_unified_colormap(settings: dict) -> tuple[mcolors.ListedColormap, m
cmap = plt.get_cmap("viridis") cmap = plt.get_cmap("viridis")
# Sample colors evenly across the colormap # Sample colors evenly across the colormap
indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends indices = np.linspace(0.1, 0.9, n_classes) # Avoid extreme ends
base_colors = [mcolors.rgb2hex(cmap(idx)[:3]) for idx in indices] base_colors = [mcolors.rgb2hex(cmap(idx).tolist()[:3]) for idx in indices]
# Create matplotlib colormap (for ultraplot and geopandas) # Create matplotlib colormap (for ultraplot and geopandas)
matplotlib_cmap = mcolors.ListedColormap(base_colors) matplotlib_cmap = mcolors.ListedColormap(base_colors)

View file

@ -0,0 +1,96 @@
"""Formatters for dashboard display."""
from dataclasses import dataclass
from datetime import datetime
from typing import Literal
from entropice.utils.types import Grid, Model, Task
@dataclass
class ModelDisplayInfo:
    """Human-readable display names for one model keyword."""

    keyword: Model  # canonical model identifier used across the codebase
    short: str  # short display name, e.g. "eSPA"
    long: str  # spelled-out name shown next to the short one

    def __str__(self) -> str:
        """Render as ``short (long)``."""
        return f"{self.short} ({self.long})"


# Registry mapping every supported model keyword to its display info.
model_display_infos: dict[Model, ModelDisplayInfo] = {
    "espa": ModelDisplayInfo(
        keyword="espa", short="eSPA", long="entropy-optimal Scalable Probabilistic Approximations algorithm"
    ),
    "xgboost": ModelDisplayInfo(keyword="xgboost", short="XGBoost", long="Extreme Gradient Boosting"),
    "rf": ModelDisplayInfo(keyword="rf", short="Random Forest", long="Random Forest Classifier"),
    "knn": ModelDisplayInfo(keyword="knn", short="k-NN", long="k-Nearest Neighbors Classifier"),
}
@dataclass
class TaskDisplayInfo:
    """Human-readable title and explanation for one prediction task."""

    keyword: Task  # canonical task identifier ("binary", "count", "density")
    display_name: str  # title shown in the dashboard
    explanation: str  # one-sentence description of the task


# Registry mapping every supported task keyword to its display info.
task_display_infos: dict[Task, TaskDisplayInfo] = {
    "binary": TaskDisplayInfo(
        keyword="binary",
        display_name="Binary Classification",
        explanation="Classify each grid cell as containing or not containing the target feature.",
    ),
    "count": TaskDisplayInfo(
        keyword="count",
        display_name="Count Prediction",
        explanation="Predict the number of target features present within each grid cell.",
    ),
    "density": TaskDisplayInfo(
        keyword="density",
        display_name="Density Estimation",
        explanation="Estimate the density of target features within each grid cell.",
    ),
}
@dataclass
class TrainingResultDisplayInfo:
    """Display metadata identifying a single training result."""

    task: Task
    model: Model
    grid: Grid
    level: int
    timestamp: datetime

    def get_display_name(self, format_type: Literal["task_first", "model_first"] = "task_first") -> str:
        """Build a human-readable result name, leading with either the task or the model."""
        when = self.timestamp.strftime("%Y-%m-%d %H:%M")
        tail = f"{self.grid.capitalize()}-{self.level} ({when})"
        if format_type == "model_first":
            head = f"{self.model.upper()} - {self.task.capitalize()}"
        else:  # task_first (default)
            head = f"{self.task.capitalize()} - {self.model.upper()}"
        return f"{head} - {tail}"
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
    """
    # "f1" is an initialism and keeps its capital F; every other
    # underscore-separated token is simply title-cased.
    return " ".join(
        "F1" if token.lower() == "f1" else token.capitalize() for token in metric.split("_")
    )

View file

@ -0,0 +1,163 @@
"""Data utilities for Entropice dashboard."""
import pickle
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import antimeridian
import geopandas as gpd
import pandas as pd
import streamlit as st
import toml
import xarray as xr
from shapely.geometry import shape
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.formatters import TrainingResultDisplayInfo
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
from entropice.ml.training import TrainingSettings
from entropice.utils.types import L2SourceDataset, Task
def _fix_hex_geometry(geom):
    """Fix hexagon geometry crossing the antimeridian.

    Returns the repaired geometry, or the original geometry unchanged when the
    antimeridian fix raises ValueError (the error is surfaced in the dashboard
    via ``st.error`` instead of being propagated).
    """
    try:
        return shape(antimeridian.fix_shape(geom))
    except ValueError as e:
        st.error(f"Error fixing geometry: {e}")
        return geom
@dataclass
class TrainingResult:
    """Wrapper around one training-result directory on disk.

    Attributes:
        path: Directory containing the result artifacts.
        settings: Training settings parsed from ``search_settings.toml``.
        results: Hyperparameter-search CV results table.
        created_at: Directory creation time (``st_ctime``), used for sorting.
        available_metrics: Metric names found as ``mean_test_*`` columns.
    """

    path: Path
    settings: TrainingSettings
    results: pd.DataFrame
    created_at: float
    available_metrics: list[str]

    @classmethod
    def from_path(cls, result_path: Path) -> "TrainingResult":
        """Load a TrainingResult from a given result directory path.

        Raises:
            FileNotFoundError: If any required artifact is missing in the directory.
        """
        result_file = result_path / "search_results.parquet"
        preds_file = result_path / "predicted_probabilities.parquet"
        settings_file = result_path / "search_settings.toml"
        if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
            raise FileNotFoundError(f"Missing required files in {result_path}")
        created_at = result_path.stat().st_ctime
        settings = TrainingSettings(**(toml.load(settings_file)["settings"]))
        results = pd.read_parquet(result_file)
        available_metrics = [col.replace("mean_test_", "") for col in results.columns if col.startswith("mean_test_")]
        return cls(
            path=result_path,
            settings=settings,
            results=results,
            created_at=created_at,
            available_metrics=available_metrics,
        )

    @property
    def display_info(self) -> TrainingResultDisplayInfo:
        """Display metadata derived from the training settings."""
        return TrainingResultDisplayInfo(
            task=self.settings.task,
            model=self.settings.model,
            grid=self.settings.grid,
            level=self.settings.level,
            timestamp=datetime.fromtimestamp(self.created_at),
        )

    def _load_artifact(self, filename: str, loader, what: str):
        """Shared load-or-None logic for the optional result artifacts.

        Returns None when the file is missing or the loader fails; failures are
        surfaced in the dashboard via ``st.error`` rather than propagated.
        """
        artifact_file = self.path / filename
        if not artifact_file.exists():
            return None
        try:
            return loader(artifact_file)
        except Exception as e:
            st.error(f"Error loading {what}: {e}")
            return None

    def load_best_model(self) -> object | None:
        """Load the best model from a training result."""

        def _read_pickle(path: Path) -> object:
            # NOTE: unpickling executes arbitrary code; only load trusted,
            # locally produced result directories.
            with open(path, "rb") as f:
                return pickle.load(f)

        return self._load_artifact("best_estimator_model.pkl", _read_pickle, "model")

    def load_model_state(self) -> xr.Dataset | None:
        """Load the model state from a training result."""
        return self._load_artifact(
            "best_estimator_state.nc", lambda p: xr.open_dataset(p, engine="h5netcdf"), "model state"
        )

    def load_predictions(self) -> pd.DataFrame | None:
        """Load predictions from a training result."""
        return self._load_artifact("predicted_probabilities.parquet", pd.read_parquet, "predictions")
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
    """Discover and load all training results under RESULTS_DIR.

    Returns:
        Training results sorted by creation time (most recent first).
        Directories missing required artifacts are skipped.
    """
    results_dir = entropice.utils.paths.RESULTS_DIR
    training_results: list[TrainingResult] = []
    for result_path in results_dir.iterdir():
        if not result_path.is_dir():
            continue
        try:
            training_results.append(TrainingResult.from_path(result_path))
        except FileNotFoundError:
            # BUG FIX: a single incomplete result directory (e.g. an aborted
            # run) previously raised here and broke the entire listing.
            continue
    # Sort by creation time (most recent first)
    training_results.sort(key=lambda tr: tr.created_at, reverse=True)
    return training_results
def load_all_training_data(e: DatasetEnsemble) -> dict[Task, CategoricalTrainingDataset]:
    """Load training data for all three tasks.

    Args:
        e: DatasetEnsemble object.

    Returns:
        Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
    """
    tasks = ("binary", "count", "density")
    return {task: e.create_cat_training_dataset(task, device="cpu") for task in tasks}
def load_source_data(e: DatasetEnsemble, source: L2SourceDataset) -> tuple[xr.Dataset, gpd.GeoDataFrame]:
    """Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).

    Args:
        e: DatasetEnsemble object.
        source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.

    Returns:
        Tuple of (xarray.Dataset with the raw data for the specified source,
        GeoDataFrame with the ensemble's target data).
    """
    # NOTE(review): relies on DatasetEnsemble's private API (_read_target /
    # _read_member); consider promoting these to public methods.
    targets = e._read_target()
    # lazy=False requests eager loading of the member data -- the previous
    # comment claimed "lazily"; TODO confirm which is intended.
    ds = e._read_member(source, targets, lazy=False)
    return ds, targets

View file

@ -0,0 +1,482 @@
"""General Statistics shared across multiple dashboard pages.
- Dataset statistics: Feature Counts, Class Distributions, Temporal Coverage, all per grid-level-combination
"""
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Literal, cast, get_args
import geopandas as gpd
import pandas as pd
import streamlit as st
import xarray as xr
from stopuhr import stopwatch
import entropice.spatial.grids
import entropice.utils.paths
from entropice.dashboard.utils.loaders import TrainingResult
from entropice.ml.dataset import DatasetEnsemble, bin_values, covcol, taskcol
from entropice.utils.types import Grid, GridLevel, L2SourceDataset, TargetDataset, Task
@dataclass(frozen=True)
class MemberStatistics:
    """Statistics for a specific dataset member."""

    feature_count: int  # Number of features from this member
    variable_names: list[str]  # Names of variables from this member
    dimensions: dict[str, int]  # Dimension name to size mapping
    coordinates: list[str]  # Names of coordinates from this member
    size_bytes: int  # Size of this member's data on disk in bytes

    @classmethod
    def compute(cls, grid: Grid, level: int, member: L2SourceDataset) -> "MemberStatistics":
        """Compute on-disk statistics for one source dataset member.

        Args:
            grid: Grid type the member is stored on.
            level: Grid refinement level.
            member: Which source dataset to inspect.

        Raises:
            NotImplementedError: For members without a known store location.
        """
        # Resolve the store path for the requested member.
        if member == "AlphaEarth":
            store = entropice.utils.paths.get_embeddings_store(grid=grid, level=level)
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            era5_agg = member.split("-")[1]
            store = entropice.utils.paths.get_era5_stores(era5_agg, grid=grid, level=level)  # ty:ignore[invalid-argument-type]
        elif member == "ArcticDEM":
            store = entropice.utils.paths.get_arcticdem_stores(grid=grid, level=level)
        else:
            raise NotImplementedError(f"Member {member} not implemented.")
        # NOTE(review): st_size of a zarr store path may only reflect the
        # directory entry, not the summed chunk sizes -- TODO confirm intent.
        size_bytes = store.stat().st_size
        ds = xr.open_zarr(store, consolidated=False)
        # Delete all coordinates which are not in the dimension
        for coord in ds.coords:
            if coord not in ds.dims:
                ds = ds.drop_vars(coord)
        # One feature per data variable, fanned out by every non-cell
        # dimension (extra axes become additional feature columns).
        n_cols_member = len(ds.data_vars)
        for dim in ds.sizes:
            if dim != "cell_ids":
                n_cols_member *= ds.sizes[dim]
        return cls(
            feature_count=n_cols_member,
            variable_names=list(ds.data_vars),
            dimensions=dict(ds.sizes),
            coordinates=list(ds.coords),
            size_bytes=size_bytes,
        )
@dataclass(frozen=True)
class TargetStatistics:
    """Statistics for a specific target dataset."""

    training_cells: dict[Task, int]  # Number of cells used for training
    coverage: dict[Task, float]  # Percentage of total cells covered by training data
    class_counts: dict[Task, dict[str, int]]  # Class name to count mapping per task
    class_distribution: dict[Task, dict[str, float]]  # Class name to percentage mapping per task
    size_bytes: int  # Size of the target dataset on disk in bytes

    @classmethod
    def compute(cls, grid: Grid, level: int, target: TargetDataset, total_cells: int) -> "TargetStatistics":
        """Compute per-task label statistics for one target dataset.

        Args:
            grid: Grid type the target is defined on.
            level: Grid refinement level.
            target: Which target dataset to analyse.
            total_cells: Total cell count of the grid (denominator for coverage).

        Raises:
            NotImplementedError: For targets without a known file location.
        """
        if target == "darts_rts":
            target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level)
        elif target == "darts_mllabels":
            target_store = entropice.utils.paths.get_darts_rts_file(grid=grid, level=level, labels=True)
        else:
            raise NotImplementedError(f"Target {target} not implemented.")
        target_gdf = gpd.read_parquet(target_store)
        size_bytes = target_store.stat().st_size
        training_cells: dict[Task, int] = {}
        training_coverage: dict[Task, float] = {}
        class_counts: dict[Task, dict[str, int]] = {}
        class_distribution: dict[Task, dict[str, float]] = {}
        tasks = cast(list[Task], get_args(Task))
        for task in tasks:
            task_col = taskcol[task][target]
            cov_col = covcol[target]
            # Only cells flagged as covered by this target contribute labels.
            task_gdf = target_gdf[target_gdf[cov_col]]
            training_cells[task] = len(task_gdf)
            training_coverage[task] = len(task_gdf) / total_cells * 100
            model_labels = task_gdf[task_col].dropna()
            # Bin raw labels into the task's class scheme.
            if task == "binary":
                binned = model_labels.map({False: "No RTS", True: "RTS"}).astype("category")
            elif task == "count":
                binned = bin_values(model_labels.astype(int), task=task)
            elif task == "density":
                binned = bin_values(model_labels, task=task)
            else:
                raise ValueError("Invalid task.")
            counts = binned.value_counts()
            distribution = counts / counts.sum() * 100
            class_counts[task] = counts.to_dict()  # ty:ignore[invalid-assignment]
            class_distribution[task] = distribution.to_dict()
        return TargetStatistics(
            training_cells=training_cells,
            coverage=training_coverage,
            class_counts=class_counts,
            class_distribution=class_distribution,
            size_bytes=size_bytes,
        )
@dataclass(frozen=True)
class DatasetStatistics:
    """Statistics for a potential dataset at a specific grid and level.

    These statistics are meant to be easily computed without loading a full dataset into memory.
    Further, they are meant to give an overview of all potential dataset compositions for a given grid and level.
    """

    total_features: int  # Total number of features available in the dataset
    total_cells: int  # Total number of grid cells potentially covered
    size_bytes: int  # Size of the dataset on disk in bytes
    members: dict[L2SourceDataset, MemberStatistics]  # Statistics per source dataset member
    target: dict[TargetDataset, TargetStatistics]  # Statistics per target dataset
@st.cache_data
def load_all_default_dataset_statistics() -> dict[GridLevel, DatasetStatistics]:
    """Compute DatasetStatistics for every default grid/level combination.

    Results are cached by Streamlit (``st.cache_data``) for the session.

    Returns:
        Mapping from grid-level key (e.g. ``"hex5"``) to its statistics.
    """
    dataset_stats: dict[GridLevel, DatasetStatistics] = {}
    # NOTE: set iteration order is arbitrary, so the insertion order of the
    # returned dict is not stable across runs.
    grid_levels: set[tuple[Grid, int]] = {
        ("hex", 3),
        ("hex", 4),
        ("hex", 5),
        ("hex", 6),
        ("healpix", 6),
        ("healpix", 7),
        ("healpix", 8),
        ("healpix", 9),
        ("healpix", 10),
    }
    for grid, level in grid_levels:
        with stopwatch(f"Loading statistics for grid={grid}, level={level}"):
            grid_gdf = entropice.spatial.grids.open(grid, level)  # Ensure grid is registered
            total_cells = len(grid_gdf)
            assert total_cells > 0, "Grid must contain at least one cell."
            # Per-target label statistics.
            target_statistics: dict[TargetDataset, TargetStatistics] = {}
            targets = cast(list[TargetDataset], get_args(TargetDataset))
            for target in targets:
                target_statistics[target] = TargetStatistics.compute(
                    grid=grid, level=level, target=target, total_cells=total_cells
                )
            # Per-member feature statistics.
            member_statistics: dict[L2SourceDataset, MemberStatistics] = {}
            members = cast(list[L2SourceDataset], get_args(L2SourceDataset))
            for member in members:
                member_statistics[member] = MemberStatistics.compute(grid=grid, level=level, member=member)
            total_features = sum(ms.feature_count for ms in member_statistics.values())
            total_size_bytes = sum(ms.size_bytes for ms in member_statistics.values()) + sum(
                ts.size_bytes for ts in target_statistics.values()
            )
            grid_level: GridLevel = cast(GridLevel, f"{grid}{level}")
            dataset_stats[grid_level] = DatasetStatistics(
                total_features=total_features,
                total_cells=total_cells,
                size_bytes=total_size_bytes,
                members=member_statistics,
                target=target_statistics,
            )
    return dataset_stats
@dataclass(frozen=True)
class EnsembleMemberStatistics:
    """Per-member feature statistics within an assembled ensemble dataset."""

    n_features: int  # Number of features from this member in the ensemble
    p_features: float  # Fraction (0..1) of features from this member in the ensemble
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # Size of this member's data in the ensemble in bytes

    @classmethod
    def compute(
        cls,
        dataset: "gpd.GeoDataFrame",
        member: "L2SourceDataset",
        n_features: int,
    ) -> "EnsembleMemberStatistics":
        """Compute statistics over the columns contributed by ``member``.

        Args:
            dataset: Assembled ensemble dataset (cells x features).
            member: Source dataset whose columns should be analysed.
            n_features: Total feature count of ``dataset`` (for ``p_features``).

        Raises:
            NotImplementedError: For unknown members.
        """
        # BUG FIX: ``dataset[mask]`` with a column-length boolean mask is
        # interpreted by pandas as a *row* indexer and raises "Item wrong
        # length"; selecting columns requires ``.loc[:, mask]``.
        if member == "AlphaEarth":
            member_dataset = dataset.loc[:, dataset.columns.str.startswith("embeddings_")]
        elif member == "ArcticDEM":
            member_dataset = dataset.loc[:, dataset.columns.str.startswith("arcticdem_")]
        elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
            era5_cols = dataset.columns.str.startswith("era5_")
            # Yearly aggregates look like "era5_<var>_<stat>" (2 underscores);
            # seasonal/shoulder names carry an extra token (3 underscores).
            cols_with_three_splits = era5_cols & (dataset.columns.str.count("_") == 2)
            cols_with_four_splits = era5_cols & (dataset.columns.str.count("_") == 3)
            cols_with_summer_winter = era5_cols & (dataset.columns.str.contains("_summer|_winter"))
            if member == "ERA5-yearly":
                member_dataset = dataset.loc[:, cols_with_three_splits]
            elif member == "ERA5-seasonal":
                member_dataset = dataset.loc[:, cols_with_summer_winter & cols_with_four_splits]
            elif member == "ERA5-shoulder":
                member_dataset = dataset.loc[:, ~cols_with_summer_winter & cols_with_four_splits]
        else:
            raise NotImplementedError(f"Member {member} not implemented.")
        size_bytes_member = member_dataset.memory_usage(deep=True).sum()
        n_features_member = len(member_dataset.columns)
        p_features_member = n_features_member / n_features
        n_rows_with_nan_member = member_dataset.isna().any(axis=1).sum()
        return EnsembleMemberStatistics(
            n_features=n_features_member,
            p_features=p_features_member,
            n_nanrows=n_rows_with_nan_member,
            size_bytes=size_bytes_member,
        )
@dataclass(frozen=True)
class EnsembleDatasetStatistics:
    """Statistics for a specified composition / ensemble at a specific grid and level.

    These statistics are meant to be only computed on user demand, thus by loading a dataset into memory.
    That way, the real number of features and cells after NaN filtering can be reported.
    """

    n_features: int  # Number of features in the dataset
    n_cells: int  # Number of grid cells covered
    n_nanrows: int  # Number of rows which contain any NaN
    size_bytes: int  # In-memory size of the dataset in bytes (memory_usage)
    members: dict[L2SourceDataset, EnsembleMemberStatistics]  # Statistics per source dataset member

    @classmethod
    def compute(cls, ensemble: DatasetEnsemble, dataset: gpd.GeoDataFrame | None = None) -> "EnsembleDatasetStatistics":
        """Compute ensemble statistics, building the dataset from ``ensemble`` if not given.

        Args:
            ensemble: Ensemble definition (also used to enumerate members).
            dataset: Optional pre-built ensemble dataset; created when None.
        """
        # BUG FIX: ``dataset = dataset or ...`` calls DataFrame.__bool__, which
        # raises "truth value of a DataFrame is ambiguous" whenever a dataset
        # is actually passed in. Check for None explicitly instead.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)
        # Assert that no column is all-NaN
        assert not dataset.isna().all("index").any(), "Some input columns are all NaN"
        size_bytes = dataset.memory_usage(deep=True).sum()
        n_features = len(dataset.columns)
        n_cells = len(dataset)
        # Number of rows which contain any NaN
        n_rows_with_nan = dataset.isna().any(axis=1).sum()
        member_statistics: dict[L2SourceDataset, EnsembleMemberStatistics] = {}
        for member in ensemble.members:
            member_statistics[member] = EnsembleMemberStatistics.compute(
                dataset=dataset,
                member=member,
                n_features=n_features,
            )
        return cls(
            n_features=n_features,
            n_cells=n_cells,
            n_nanrows=n_rows_with_nan,
            size_bytes=size_bytes,
            members=member_statistics,
        )
@dataclass(frozen=True)
class TrainingDatasetStatistics:
    """Statistics of one categorical training dataset: features, splits and labels."""

    n_samples: int  # Total number of samples in the dataset
    n_features: int  # Number of features
    feature_names: list[str]  # Names of all features
    n_training_samples: int  # Number of cells used for training
    n_test_samples: int  # Number of cells used for testing
    train_test_ratio: float  # Ratio of training to test samples
    class_labels: list[str]  # Ordered list of class labels
    class_intervals: list[tuple[float, float] | tuple[int, int] | tuple[None, None]]  # Min/max raw values per class
    n_classes: int  # Number of classes
    training_class_counts: dict[str, int]  # Class counts in training set
    training_class_distribution: dict[str, float]  # Class percentages in training set
    test_class_counts: dict[str, int]  # Class counts in test set
    test_class_distribution: dict[str, float]  # Class percentages in test set
    raw_value_min: float  # Minimum raw target value
    raw_value_max: float  # Maximum raw target value
    raw_value_mean: float  # Mean raw target value
    raw_value_median: float  # Median raw target value
    raw_value_std: float  # Standard deviation of raw target values
    imbalance_ratio: float  # Smallest class count / largest class count (overall)
    size_bytes: int  # Total memory usage of features in bytes

    @classmethod
    def compute(
        cls, ensemble: DatasetEnsemble, task: Task, dataset: gpd.GeoDataFrame | None = None
    ) -> "TrainingDatasetStatistics":
        """Compute statistics for the categorical training dataset of ``task``.

        Args:
            ensemble: Ensemble used to build and split the dataset.
            task: Prediction task whose labels are analysed.
            dataset: Optional pre-built ensemble dataset; created when None.
        """
        # BUG FIX: ``dataset = dataset or ...`` calls DataFrame.__bool__, which
        # raises ValueError for any real (Geo)DataFrame argument; check None.
        if dataset is None:
            dataset = ensemble.create(filter_target_col=ensemble.covcol)
        categorical_dataset = ensemble._cat_and_split(dataset, task=task, device="cpu")
        # Sample counts
        n_samples = len(categorical_dataset)
        n_training_samples = len(categorical_dataset.y.train)
        n_test_samples = len(categorical_dataset.y.test)
        train_test_ratio = n_training_samples / n_test_samples if n_test_samples > 0 else 0.0
        # Feature statistics
        n_features = len(categorical_dataset.X.data.columns)
        feature_names = list(categorical_dataset.X.data.columns)
        size_bytes = categorical_dataset.X.data.memory_usage(deep=True).sum()
        # Class information
        class_labels = categorical_dataset.y.labels
        class_intervals = categorical_dataset.y.intervals
        n_classes = len(class_labels)
        # Training class distribution
        train_y_series = pd.Series(categorical_dataset.y.train)
        train_counts = train_y_series.value_counts().sort_index()
        training_class_counts = {class_labels[i]: int(train_counts.get(i, 0)) for i in range(n_classes)}
        train_total = sum(training_class_counts.values())
        training_class_distribution = {
            k: (v / train_total * 100) if train_total > 0 else 0.0 for k, v in training_class_counts.items()
        }
        # Test class distribution
        test_y_series = pd.Series(categorical_dataset.y.test)
        test_counts = test_y_series.value_counts().sort_index()
        test_class_counts = {class_labels[i]: int(test_counts.get(i, 0)) for i in range(n_classes)}
        test_total = sum(test_class_counts.values())
        test_class_distribution = {
            k: (v / test_total * 100) if test_total > 0 else 0.0 for k, v in test_class_counts.items()
        }
        # Raw value statistics
        raw_values = categorical_dataset.y.raw_values
        raw_value_min = float(raw_values.min())
        raw_value_max = float(raw_values.max())
        raw_value_mean = float(raw_values.mean())
        raw_value_median = float(raw_values.median())
        raw_value_std = float(raw_values.std())
        # Imbalance ratio (smallest class / largest class across both splits)
        all_counts = list(training_class_counts.values()) + list(test_class_counts.values())
        nonzero_counts = [c for c in all_counts if c > 0]
        imbalance_ratio = min(nonzero_counts) / max(nonzero_counts) if nonzero_counts else 0.0
        return cls(
            n_samples=n_samples,
            n_features=n_features,
            feature_names=feature_names,
            n_training_samples=n_training_samples,
            n_test_samples=n_test_samples,
            train_test_ratio=train_test_ratio,
            class_labels=class_labels,
            class_intervals=class_intervals,
            n_classes=n_classes,
            training_class_counts=training_class_counts,
            training_class_distribution=training_class_distribution,
            test_class_counts=test_class_counts,
            test_class_distribution=test_class_distribution,
            raw_value_min=raw_value_min,
            raw_value_max=raw_value_max,
            raw_value_mean=raw_value_mean,
            raw_value_median=raw_value_median,
            raw_value_std=raw_value_std,
            imbalance_ratio=imbalance_ratio,
            size_bytes=size_bytes,
        )
@dataclass(frozen=True)
class CVMetricStatistics:
    """Summary statistics of one CV metric across all search candidates."""

    best_score: float
    mean_score: float
    std_score: float
    worst_score: float
    median_score: float
    mean_cv_std: float | None

    @classmethod
    def compute(cls, result: TrainingResult, metric: str) -> "CVMetricStatistics":
        """Get cross-validation statistics for a metric."""
        score_col = f"mean_test_{metric}"
        std_col = f"std_test_{metric}"
        if score_col not in result.results.columns:
            raise ValueError(f"Metric {metric} not found in results.")
        scores = result.results[score_col]
        cv_stds = result.results[std_col] if std_col in result.results.columns else None
        return CVMetricStatistics(
            best_score=scores.max(),
            mean_score=scores.mean(),
            std_score=scores.std(),
            worst_score=scores.min(),
            median_score=scores.median(),
            mean_cv_std=None if cv_stds is None else cv_stds.mean(),
        )
@dataclass(frozen=True)
class ParameterSpaceSummary:
    """Summary of the values explored for one search parameter."""

    parameter: str
    type: Literal["Numeric", "Categorical"]
    min: float | None
    max: float | None
    mean: float | None
    unique_values: int

    @classmethod
    def compute(cls, result: TrainingResult, param_col: str) -> "ParameterSpaceSummary":
        """Summarise the explored values of the ``param_*`` column ``param_col``."""
        name = param_col.replace("param_", "")
        values = result.results[param_col].dropna()
        is_numeric = pd.api.types.is_numeric_dtype(values)
        return ParameterSpaceSummary(
            parameter=name,
            type="Numeric" if is_numeric else "Categorical",
            min=values.min() if is_numeric else None,
            max=values.max() if is_numeric else None,
            mean=values.mean() if is_numeric else None,
            unique_values=values.nunique() if is_numeric else len(values.unique()),
        )
@dataclass(frozen=True)
class CVResultsStatistics:
    """Aggregated CV statistics: per-metric scores plus the explored parameter space."""

    metrics: dict[str, CVMetricStatistics]
    parameter_summary: list[ParameterSpaceSummary]

    @classmethod
    def compute(cls, result: TrainingResult) -> "CVResultsStatistics":
        """Get cross-validation statistics for a metric."""
        metric_stats = {metric: CVMetricStatistics.compute(result, metric) for metric in result.available_metrics}
        param_cols = [col for col in result.results.columns if col.startswith("param_") and col != "params"]
        summaries = [ParameterSpaceSummary.compute(result, col) for col in param_cols]
        return CVResultsStatistics(metrics=metric_stats, parameter_summary=summaries)

    def metrics_to_dataframe(self) -> pd.DataFrame:
        """Convert metric statistics to a DataFrame."""
        rows = [
            {
                "Metric": metric,
                "Best Score": stats.best_score,
                "Mean Score": stats.mean_score,
                "Std Dev": stats.std_score,
                "Worst Score": stats.worst_score,
                "Median Score": stats.median_score,
                "Mean CV Std Dev": stats.mean_cv_std,
            }
            for metric, stats in self.metrics.items()
        ]
        return pd.DataFrame(rows)

    def parameters_to_dataframe(self) -> pd.DataFrame:
        """Convert parameter summary to a DataFrame."""
        return pd.DataFrame([asdict(p) for p in self.parameter_summary])

View file

@ -1,232 +0,0 @@
"""Training utilities for dashboard."""
import pickle
import numpy as np
import pandas as pd
import streamlit as st
import xarray as xr
from entropice.dashboard.utils.data import TrainingResult
def format_metric_name(metric: str) -> str:
    """Format metric name for display.

    Args:
        metric: Raw metric name (e.g., 'f1_micro', 'precision_macro').

    Returns:
        Formatted metric name (e.g., 'F1 Micro', 'Precision Macro').
    """

    def _format_part(part: str) -> str:
        # Keep the F1 initialism intact; title-case everything else.
        return "F1" if part.lower() == "f1" else part.capitalize()

    return " ".join(_format_part(part) for part in metric.split("_"))
def get_available_metrics(results: pd.DataFrame) -> list[str]:
    """Get list of available metrics from results.

    Args:
        results: DataFrame with CV results.

    Returns:
        List of metric names (without 'mean_test_' prefix).
    """
    prefix = "mean_test_"
    return [col.replace(prefix, "") for col in results.columns if col.startswith(prefix)]
def load_best_model(result: TrainingResult):
    """Load the best model from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        The loaded model object, or None if loading fails.
    """
    model_file = result.path / "best_estimator_model.pkl"
    if not model_file.exists():
        return None
    try:
        # NOTE: unpickling runs arbitrary code; only load trusted results.
        with open(model_file, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None
def load_model_state(result: TrainingResult) -> xr.Dataset | None:
    """Load the model state from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        xarray Dataset with model state, or None if not available.
    """
    state_file = result.path / "best_estimator_state.nc"
    if not state_file.exists():
        return None
    try:
        return xr.open_dataset(state_file, engine="h5netcdf")
    except Exception as e:
        st.error(f"Error loading model state: {e}")
        return None
def load_predictions(result: TrainingResult) -> pd.DataFrame | None:
    """Load predictions from a training result.

    Args:
        result: TrainingResult object.

    Returns:
        DataFrame with predictions, or None if not available.
    """
    preds_file = result.path / "predicted_probabilities.parquet"
    if not preds_file.exists():
        return None
    try:
        return pd.read_parquet(preds_file)
    except Exception as e:
        st.error(f"Error loading predictions: {e}")
        return None
def get_parameter_space_summary(results: pd.DataFrame) -> pd.DataFrame:
    """Get summary of parameter space explored.

    Args:
        results: DataFrame with CV results.

    Returns:
        DataFrame with parameter ranges and statistics.
    """
    rows = []
    # Only real search parameters; the aggregate "params" column is skipped.
    for param_col in (c for c in results.columns if c.startswith("param_") and c != "params"):
        values = results[param_col].dropna()
        name = param_col.replace("param_", "")
        if pd.api.types.is_numeric_dtype(values):
            row = {
                "Parameter": name,
                "Type": "Numeric",
                "Min": f"{values.min():.2e}",
                "Max": f"{values.max():.2e}",
                "Mean": f"{values.mean():.2e}",
                "Unique Values": values.nunique(),
            }
        else:
            row = {
                "Parameter": name,
                "Type": "Categorical",
                "Min": "-",
                "Max": "-",
                "Mean": "-",
                "Unique Values": len(values.unique()),
            }
        rows.append(row)
    return pd.DataFrame(rows)
def get_cv_statistics(results: pd.DataFrame, metric: str) -> dict:
    """Get cross-validation statistics for a metric.

    Args:
        results: DataFrame with CV results.
        metric: Metric name (without 'mean_test_' prefix).

    Returns:
        Dictionary with CV statistics (empty if the metric is missing).
    """
    score_col = f"mean_test_{metric}"
    if score_col not in results.columns:
        return {}
    scores = results[score_col]
    stats = {
        "best_score": scores.max(),
        "mean_score": scores.mean(),
        "std_score": scores.std(),
        "worst_score": scores.min(),
        "median_score": scores.median(),
    }
    std_col = f"std_test_{metric}"
    if std_col in results.columns:
        stats["mean_cv_std"] = results[std_col].mean()
    return stats
def prepare_results_for_plotting(results: pd.DataFrame, k_bin_width: int = 40) -> pd.DataFrame:
    """Prepare results dataframe with binned columns for plotting.

    Adds categorical "binned" companion columns for the ``initial_K``,
    ``eps_cl`` and ``eps_e`` parameters so downstream plots can group by
    parameter ranges instead of raw values.

    Args:
        results: DataFrame with CV results.
        k_bin_width: Width of bins for the initial_K parameter.

    Returns:
        Copy of ``results`` with added binned columns; the input is untouched.
    """
    results_copy = results.copy()

    if "param_initial_K" in results.columns:
        k_values = results["param_initial_K"].dropna()
        if len(k_values) > 0:
            k_min = int(k_values.min())
            k_max = k_values.max()
            # Left-closed bins [edge, edge + width). Extend the edges until the
            # last edge is strictly greater than k_max: previously a K value
            # equal to the final edge fell outside every bin (NaN), and a
            # single unique K value produced only one edge, which pd.cut rejects.
            k_edges = list(range(k_min, int(k_max) + k_bin_width, k_bin_width))
            while k_edges[-1] <= k_max:
                k_edges.append(k_edges[-1] + k_bin_width)
            results_copy["initial_K_binned"] = pd.cut(results["param_initial_K"], bins=k_edges, right=False)

    for param in ("eps_cl", "eps_e"):
        col = f"param_{param}"
        if col not in results.columns:
            continue
        values = results[col].dropna()
        # log10 requires strictly positive values.
        if len(values) > 0 and values.min() > 0:
            # Logarithmic bins; include_lowest=True so the minimum value
            # (which coincides with the first edge) is kept instead of being
            # dropped as NaN by pd.cut's default (a, b] intervals.
            edges = np.logspace(np.log10(values.min()), np.log10(values.max()), num=10)
            results_copy[f"{param}_binned"] = pd.cut(results[col], bins=edges, include_lowest=True)

    return results_copy

View file

@ -1,147 +1,4 @@
"""Data utilities for Entropice dashboard."""
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import antimeridian
import pandas as pd
import streamlit as st
import toml
import xarray as xr import xarray as xr
from shapely.geometry import shape
import entropice.utils.paths
from entropice.ml.dataset import CategoricalTrainingDataset, DatasetEnsemble
@dataclass
class TrainingResult:
"""Simple wrapper of training result data."""
name: str
path: Path
settings: dict
results: pd.DataFrame
created_at: float
def get_display_name(self, format_type: str = "task_first") -> str:
"""Get formatted display name for the training result.
Args:
format_type: Either 'task_first' (for training analysis) or 'model_first' (for model state)
Returns:
Formatted name string
"""
task = self.settings.get("task", "Unknown").capitalize()
model = self.settings.get("model", "Unknown").upper()
grid = self.settings.get("grid", "Unknown").capitalize()
level = self.settings.get("level", "Unknown")
timestamp = datetime.fromtimestamp(self.created_at).strftime("%Y-%m-%d %H:%M")
if format_type == "model_first":
return f"{model} - {task} - {grid}-{level} ({timestamp})"
else: # task_first
return f"{task} - {model} - {grid}-{level} ({timestamp})"
@classmethod
def from_path(cls, result_path: Path) -> "TrainingResult":
"""Load a TrainingResult from a given result directory path."""
result_file = result_path / "search_results.parquet"
preds_file = result_path / "predicted_probabilities.parquet"
settings_file = result_path / "search_settings.toml"
if not all([result_file.exists(), preds_file.exists(), settings_file.exists()]):
raise FileNotFoundError(f"Missing required files in {result_path}")
created_at = result_path.stat().st_ctime
settings = toml.load(settings_file)["settings"]
results = pd.read_parquet(result_file)
# Name should be "task model grid-level (created_at)"
model = settings.get("model", "Unknown").upper()
name = (
f"{settings.get('task', 'Unknown').capitalize()} -"
f" {model} -"
f" {settings.get('grid', 'Unknown').capitalize()}-{settings.get('level', 'Unknown')}"
f" ({datetime.fromtimestamp(created_at).strftime('%Y-%m-%d %H:%M')})"
)
return cls(
name=name,
path=result_path,
settings=settings,
results=results,
created_at=created_at,
)
def _fix_hex_geometry(geom):
"""Fix hexagon geometry crossing the antimeridian."""
try:
return shape(antimeridian.fix_shape(geom))
except ValueError as e:
st.error(f"Error fixing geometry: {e}")
return geom
@st.cache_data
def load_all_training_results() -> list[TrainingResult]:
"""Load all training results from the results directory."""
results_dir = entropice.utils.paths.RESULTS_DIR
training_results: list[TrainingResult] = []
for result_path in results_dir.iterdir():
if not result_path.is_dir():
continue
try:
training_result = TrainingResult.from_path(result_path)
training_results.append(training_result)
except FileNotFoundError:
continue
# Sort by creation time (most recent first)
training_results.sort(key=lambda tr: tr.created_at, reverse=True)
return training_results
@st.cache_data
def load_all_training_data(e: DatasetEnsemble) -> dict[str, CategoricalTrainingDataset]:
"""Load training data for all three tasks.
Args:
e: DatasetEnsemble object.
Returns:
Dictionary with keys 'binary', 'count', 'density' and CategoricalTrainingDataset values.
"""
return {
"binary": e.create_cat_training_dataset("binary", device="cpu"),
"count": e.create_cat_training_dataset("count", device="cpu"),
"density": e.create_cat_training_dataset("density", device="cpu"),
}
@st.cache_data
def load_source_data(e: DatasetEnsemble, source: str):
"""Load raw data from a specific source (AlphaEarth, ArcticDEM, or ERA5).
Args:
e: DatasetEnsemble object.
source: One of 'AlphaEarth', 'ArcticDEM', 'ERA5-yearly', 'ERA5-seasonal', 'ERA5-shoulder'.
Returns:
xarray.Dataset with the raw data for the specified source.
"""
targets = e._read_target()
# Load the member data lazily to get metadata
ds = e._read_member(source, targets, lazy=False)
return ds, targets
def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray: def _get_feature_importance_array(model_state: xr.Dataset, importance_type: str = "feature_weights") -> xr.DataArray:
@ -203,7 +60,7 @@ def extract_embedding_features(
band=("feature", [f.split("_")[2] for f in embedding_features]), band=("feature", [f.split("_")[2] for f in embedding_features]),
year=("feature", [f.split("_")[3] for f in embedding_features]), year=("feature", [f.split("_")[3] for f in embedding_features]),
) )
embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature") # noqa: PD010 embedding_feature_array = embedding_feature_array.set_index(feature=["agg", "band", "year"]).unstack("feature")
return embedding_feature_array return embedding_feature_array
@ -419,9 +276,9 @@ def extract_era5_features(
era5_features_array = era5_features_array.assign_coords( era5_features_array = era5_features_array.assign_coords(
agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]), agg=("feature", [_extract_agg_name(f) or "none" for f in era5_features]),
) )
era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature") # noqa: PD010 era5_features_array = era5_features_array.set_index(feature=["variable", "time", "agg"]).unstack("feature")
else: else:
era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature") # noqa: PD010 era5_features_array = era5_features_array.set_index(feature=["variable", "time"]).unstack("feature")
return era5_features_array return era5_features_array
@ -472,7 +329,7 @@ def extract_arcticdem_features(
variable=("feature", [_extract_var_name(f) for f in arcticdem_features]), variable=("feature", [_extract_var_name(f) for f in arcticdem_features]),
agg=("feature", [_extract_agg_name(f) for f in arcticdem_features]), agg=("feature", [_extract_agg_name(f) for f in arcticdem_features]),
) )
arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature") # noqa: PD010 arcticdem_feature_array = arcticdem_feature_array.set_index(feature=["variable", "agg"]).unstack("feature")
return arcticdem_feature_array return arcticdem_feature_array
@ -503,16 +360,3 @@ def extract_common_features(model_state: xr.Dataset, importance_type: str = "fea
# Extract the feature importance for common features # Extract the feature importance for common features
common_feature_array = importance_array.sel(feature=common_features) common_feature_array = importance_array.sel(feature=common_features)
return common_feature_array return common_feature_array
def get_members_from_settings(settings: dict) -> list[str]:
"""Extract the list of dataset members used in training from settings.
Args:
settings: Training settings dictionary.
Returns:
List of member dataset names (e.g., ['AlphaEarth', 'ERA5-yearly', 'ERA5-seasonal']).
"""
return settings.get("members", [])

View file

@ -3,7 +3,6 @@
import streamlit as st import streamlit as st
import xarray as xr import xarray as xr
from entropice.dashboard.plots.colors import generate_unified_colormap
from entropice.dashboard.plots.model_state import ( from entropice.dashboard.plots.model_state import (
plot_arcticdem_heatmap, plot_arcticdem_heatmap,
plot_arcticdem_summary, plot_arcticdem_summary,
@ -17,6 +16,7 @@ from entropice.dashboard.plots.model_state import (
plot_era5_time_heatmap, plot_era5_time_heatmap,
plot_top_features, plot_top_features,
) )
from entropice.dashboard.utils.colors import generate_unified_colormap
from entropice.dashboard.utils.data import ( from entropice.dashboard.utils.data import (
extract_arcticdem_features, extract_arcticdem_features,
extract_common_features, extract_common_features,

View file

@ -8,16 +8,17 @@ import pandas as pd
import plotly.express as px import plotly.express as px
import streamlit as st import streamlit as st
from entropice.dashboard.plots.colors import get_palette from entropice.dashboard.utils.colors import get_palette
from entropice.dashboard.utils.data import load_all_training_results from entropice.dashboard.utils.loaders import load_all_training_results
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble
from entropice.utils.types import Grid, TargetDataset, Task
# Type definitions for dataset statistics # Type definitions for dataset statistics
class GridConfig(TypedDict): class GridConfig(TypedDict):
"""Grid configuration specification with metadata.""" """Grid configuration specification with metadata."""
grid: str grid: Grid
level: int level: int
grid_name: str grid_name: str
grid_sort_key: str grid_sort_key: str
@ -116,7 +117,7 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache:
Results are cached to avoid redundant computations across different tabs. Results are cached to avoid redundant computations across different tabs.
""" """
# Define all possible grid configurations # Define all possible grid configurations
grid_configs_raw = [ grid_configs_raw: list[tuple[Grid, int]] = [
("hex", 3), ("hex", 3),
("hex", 4), ("hex", 4),
("hex", 5), ("hex", 5),
@ -144,8 +145,8 @@ def load_dataset_analysis_data() -> DatasetAnalysisCache:
# Compute sample counts # Compute sample counts
sample_counts: list[SampleCountData] = [] sample_counts: list[SampleCountData] = []
target_datasets = ["darts_rts", "darts_mllabels"] target_datasets: list[TargetDataset] = ["darts_rts", "darts_mllabels"]
tasks = ["binary", "count", "density"] tasks: list[Task] = ["binary", "count", "density"]
for grid_config in grid_configs: for grid_config in grid_configs:
for target in target_datasets: for target in target_datasets:
@ -810,15 +811,22 @@ def render_experiment_results(training_results):
def render_overview_page(): def render_overview_page():
"""Render the Overview page of the dashboard.""" """Render the Overview page of the dashboard."""
st.title("🏡 Training Results Overview") st.title("🏡 Entropice Overview")
st.markdown(
"""
Welcome to the Entropice Dashboard! This overview page provides a summary of your
training experiments and dataset analyses.
Use the sections below to explore training results, dataset sample counts,
and feature configurations.
"""
)
# Load training results # Load training results
training_results = load_all_training_results() training_results = load_all_training_results()
if not training_results: if not training_results:
st.warning("No training results found. Please run some training experiments first.") st.warning("No training results found. Please run some training experiments first.")
return else:
st.write(f"Found **{len(training_results)}** training result(s)") st.write(f"Found **{len(training_results)}** training result(s)")
st.divider() st.divider()

View file

@ -15,7 +15,7 @@ from entropice.dashboard.plots.source_data import (
render_era5_plots, render_era5_plots,
) )
from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map from entropice.dashboard.plots.training_data import render_all_distribution_histograms, render_spatial_map
from entropice.dashboard.utils.data import load_all_training_data, load_source_data from entropice.dashboard.utils.loaders import load_all_training_data, load_source_data
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble
from entropice.spatial import grids from entropice.spatial import grids
@ -60,19 +60,10 @@ def render_training_data_page():
# Members selection # Members selection
st.subheader("Dataset Members") st.subheader("Dataset Members")
# Check if AlphaEarth should be disabled
disable_alphaearth = (grid == "healpix" and level == 10) or (grid == "hex" and level == 6)
all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] all_members = ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
selected_members = [] selected_members = []
for member in all_members: for member in all_members:
if member == "AlphaEarth" and disable_alphaearth:
# Show disabled checkbox with explanation
st.checkbox(
member, value=False, disabled=True, help=f"AlphaEarth is not available for {grid} level {level}"
)
else:
if st.checkbox(member, value=True, help=f"Include {member} in the dataset"): if st.checkbox(member, value=True, help=f"Include {member} in the dataset"):
selected_members.append(member) selected_members.append(member)

View file

@ -6,7 +6,6 @@ Date: October 2025
import warnings import warnings
from collections.abc import Generator from collections.abc import Generator
from typing import Literal
import cudf import cudf
import cuml.cluster import cuml.cluster
@ -25,6 +24,7 @@ from rich.progress import track
from entropice.spatial import grids from entropice.spatial import grids
from entropice.utils import codecs from entropice.utils import codecs
from entropice.utils.paths import get_annual_embeddings_file, get_embeddings_store from entropice.utils.paths import get_annual_embeddings_file, get_embeddings_store
from entropice.utils.types import Grid
# Filter out the GeoDataFrame.swapaxes deprecation warning # Filter out the GeoDataFrame.swapaxes deprecation warning
warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning) warnings.filterwarnings("ignore", message=".*GeoDataFrame.swapaxes.*", category=FutureWarning)
@ -55,11 +55,11 @@ def _batch_grid(grid_gdf: gpd.GeoDataFrame, n_partitions: int) -> Generator[pd.D
@cli.command() @cli.command()
def download(grid: Literal["hex", "healpix"], level: int): def download(grid: Grid, level: int):
"""Extract satellite embeddings from Google Earth Engine and map them to a grid. """Extract satellite embeddings from Google Earth Engine and map them to a grid.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Grid): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
""" """
@ -126,11 +126,11 @@ def download(grid: Literal["hex", "healpix"], level: int):
@cli.command() @cli.command()
def combine_to_zarr(grid: Literal["hex", "healpix"], level: int): def combine_to_zarr(grid: Grid, level: int):
"""Combine yearly embeddings parquet files into a single zarr store. """Combine yearly embeddings parquet files into a single zarr store.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Grid): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
""" """

View file

@ -2,7 +2,6 @@ import datetime
import multiprocessing as mp import multiprocessing as mp
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import Literal
import cupy as cp import cupy as cp
import cupy_xarray import cupy_xarray
@ -32,6 +31,7 @@ from entropice.spatial import grids, watermask
from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_grid from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_grid
from entropice.utils import codecs from entropice.utils import codecs
from entropice.utils.paths import get_arcticdem_stores from entropice.utils.paths import get_arcticdem_stores
from entropice.utils.types import Grid
traceback.install(show_locals=True, suppress=[cyclopts]) traceback.install(show_locals=True, suppress=[cyclopts])
pretty.install() pretty.install()
@ -330,7 +330,7 @@ def _open_adem():
@cli.command() @cli.command()
def aggregate(grid: Literal["hex", "healpix"], level: int, concurrent_partitions: int = 20): def aggregate(grid: Grid, level: int, concurrent_partitions: int = 20):
mp.set_start_method("forkserver", force=True) mp.set_start_method("forkserver", force=True)
with ( with (
dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster, dd.LocalCluster(n_workers=1, threads_per_worker=32, memory_limit="20GB") as cluster,

View file

@ -6,8 +6,6 @@ Author: Tobias Hölzer
Date: October 2025 Date: October 2025
""" """
from typing import Literal
import cyclopts import cyclopts
import geopandas as gpd import geopandas as gpd
import pandas as pd import pandas as pd
@ -17,6 +15,7 @@ from stopuhr import stopwatch
from entropice.spatial import grids from entropice.spatial import grids
from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file from entropice.utils.paths import darts_ml_training_labels_repo, dartsl2_cov_file, dartsl2_file, get_darts_rts_file
from entropice.utils.types import Grid
traceback.install() traceback.install()
pretty.install() pretty.install()
@ -25,11 +24,11 @@ cli = cyclopts.App(name="darts-rts")
@cli.command() @cli.command()
def extract_darts_rts(grid: Literal["hex", "healpix"], level: int): def extract_darts_rts(grid: Grid, level: int):
"""Extract RTS labels from DARTS dataset. """Extract RTS labels from DARTS dataset.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Grid): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
""" """
@ -91,7 +90,7 @@ def extract_darts_rts(grid: Literal["hex", "healpix"], level: int):
@cli.command() @cli.command()
def extract_darts_mllabels(grid: Literal["hex", "healpix"], level: int): def extract_darts_mllabels(grid: Grid, level: int):
with stopwatch("Load data"): with stopwatch("Load data"):
grid_gdf = grids.open(grid, level) grid_gdf = grids.open(grid, level)
darts_mllabels = ( darts_mllabels = (

View file

@ -101,6 +101,7 @@ from entropice.spatial.aggregators import _Aggregations, aggregate_raster_into_g
from entropice.spatial.xvec import to_xvec from entropice.spatial.xvec import to_xvec
from entropice.utils import codecs from entropice.utils import codecs
from entropice.utils.paths import FIGURES_DIR, get_era5_stores from entropice.utils.paths import FIGURES_DIR, get_era5_stores
from entropice.utils.types import Grid
traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile]) traceback.install(show_locals=True, suppress=[cyclopts, xr, pd, cProfile])
pretty.install() pretty.install()
@ -579,7 +580,7 @@ def enrich(n_workers: int = 10, monthly: bool = True, yearly: bool = True, daily
@cli.command @cli.command
def viz( def viz(
grid: Literal["hex", "healpix"] | None = None, grid: Grid | None = None,
level: int | None = None, level: int | None = None,
agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly", agg: Literal["daily", "monthly", "yearly", "summer", "winter", "seasonal", "shoulder"] = "monthly",
high_qual: bool = False, high_qual: bool = False,
@ -587,7 +588,7 @@ def viz(
"""Visualize a small overview of ERA5 variables for a given aggregation. """Visualize a small overview of ERA5 variables for a given aggregation.
Args: Args:
grid (Literal["hex", "healpix"], optional): Grid type for spatial representation. grid (Grid, optional): Grid type for spatial representation.
If provided along with level, the ERA5 data will be decoded onto the specified grid. If provided along with level, the ERA5 data will be decoded onto the specified grid.
level (int, optional): Level of the grid for spatial representation. level (int, optional): Level of the grid for spatial representation.
If provided along with grid, the ERA5 data will be decoded onto the specified grid. If provided along with grid, the ERA5 data will be decoded onto the specified grid.
@ -653,7 +654,7 @@ def _correct_longs(ds: xr.Dataset) -> xr.Dataset:
@cli.command @cli.command
def spatial_agg( def spatial_agg(
grid: Literal["hex", "healpix"], grid: Grid,
level: int, level: int,
concurrent_partitions: int = 20, concurrent_partitions: int = 20,
): ):

View file

@ -16,7 +16,7 @@ import hashlib
import json import json
from collections.abc import Generator from collections.abc import Generator
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from functools import cached_property, lru_cache from functools import cached_property
from typing import Literal, TypedDict from typing import Literal, TypedDict
import cupy as cp import cupy as cp
@ -32,6 +32,7 @@ from sklearn import set_config
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import entropice.utils.paths import entropice.utils.paths
from entropice.utils.types import Grid, L2SourceDataset, TargetDataset, Task
traceback.install() traceback.install()
pretty.install() pretty.install()
@ -41,6 +42,27 @@ set_config(array_api_dispatch=True)
sns.set_theme("talk", "whitegrid") sns.set_theme("talk", "whitegrid")
covcol: dict[TargetDataset, str] = {
"darts_rts": "darts_has_coverage",
"darts_mllabels": "dartsml_has_coverage",
}
taskcol: dict[Task, dict[TargetDataset, str]] = {
"binary": {
"darts_rts": "darts_has_rts",
"darts_mllabels": "dartsml_has_rts",
},
"count": {
"darts_rts": "darts_rts_count",
"darts_mllabels": "dartsml_rts_count",
},
"density": {
"darts_rts": "darts_rts_density",
"darts_mllabels": "dartsml_rts_density",
},
}
def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]): def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "shoulder"]):
time_index = pd.DatetimeIndex(df.index.get_level_values("time")) time_index = pd.DatetimeIndex(df.index.get_level_values("time"))
if temporal == "yearly": if temporal == "yearly":
@ -53,10 +75,6 @@ def _get_era5_tempus(df: pd.DataFrame, temporal: Literal["yearly", "seasonal", "
return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_") return time_index.month.map(lambda x: shoulder_seasons.get(x)).str.cat(time_index.year.astype(str), sep="_")
type L2Dataset = Literal["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
type Task = Literal["binary", "count", "density"]
def bin_values( def bin_values(
values: pd.Series, values: pd.Series,
task: Literal["count", "density"], task: Literal["count", "density"],
@ -144,7 +162,7 @@ class DatasetInputs:
@dataclass(frozen=True) @dataclass(frozen=True)
class CategoricalTrainingDataset: class CategoricalTrainingDataset:
dataset: pd.DataFrame dataset: gpd.GeoDataFrame
X: DatasetInputs X: DatasetInputs
y: DatasetLabels y: DatasetLabels
z: pd.Series z: pd.Series
@ -162,12 +180,12 @@ class DatasetStats(TypedDict):
@cyclopts.Parameter("*") @cyclopts.Parameter("*")
@dataclass(frozen=True) @dataclass(frozen=True, kw_only=True)
class DatasetEnsemble: class DatasetEnsemble:
grid: Literal["hex", "healpix"] grid: Grid
level: int level: int
target: Literal["darts_rts", "darts_mllabels"] target: Literal["darts_rts", "darts_mllabels"]
members: list[L2Dataset] = field( members: list[L2SourceDataset] = field(
default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"] default_factory=lambda: ["AlphaEarth", "ArcticDEM", "ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]
) )
dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict) dimension_filters: dict[str, dict[str, list]] = field(default_factory=dict)
@ -186,19 +204,12 @@ class DatasetEnsemble:
@property @property
def covcol(self) -> str: def covcol(self) -> str:
return "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage" return covcol[self.target]
def taskcol(self, task: Task) -> str: def taskcol(self, task: Task) -> str:
if task == "binary": return taskcol[task][self.target]
return "dartsml_has_rts" if self.target == "darts_mllabels" else "darts_has_rts"
elif task == "count":
return "dartsml_rts_count" if self.target == "darts_mllabels" else "darts_rts_count"
elif task == "density":
return "dartsml_rts_density" if self.target == "darts_mllabels" else "darts_rts_density"
else:
raise ValueError(f"Invalid task: {task}")
def _read_member(self, member: L2Dataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset: def _read_member(self, member: L2SourceDataset, targets: gpd.GeoDataFrame, lazy: bool = False) -> xr.Dataset:
if member == "AlphaEarth": if member == "AlphaEarth":
store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level) store = entropice.utils.paths.get_embeddings_store(grid=self.grid, level=self.level)
elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]: elif member in ["ERA5-yearly", "ERA5-seasonal", "ERA5-shoulder"]:
@ -340,8 +351,9 @@ class DatasetEnsemble:
print(f"=== Total number of features in dataset: {stats['total_features']}") print(f"=== Total number of features in dataset: {stats['total_features']}")
@lru_cache(maxsize=1) def create(
def create(self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r") -> pd.DataFrame: self, filter_target_col: str | None = None, cache_mode: Literal["n", "o", "r"] = "r"
) -> gpd.GeoDataFrame:
# n: no cache, o: overwrite cache, r: read cache if exists # n: no cache, o: overwrite cache, r: read cache if exists
cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col) cache_file = entropice.utils.paths.get_dataset_cache(self.id(), subset=filter_target_col)
if cache_mode == "r" and cache_file.exists(): if cache_mode == "r" and cache_file.exists():
@ -425,26 +437,14 @@ class DatasetEnsemble:
print(f"Saved dataset to cache at {cache_file}.") print(f"Saved dataset to cache at {cache_file}.")
yield dataset yield dataset
def create_cat_training_dataset( def _cat_and_split(
self, task: Task, device: Literal["cpu", "cuda", "torch"] self, dataset: gpd.GeoDataFrame, task: Task, device: Literal["cpu", "cuda", "torch"]
) -> CategoricalTrainingDataset: ) -> CategoricalTrainingDataset:
"""Create a categorical dataset for training.
Args:
task (Task): Task type.
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
Returns:
CategoricalTrainingDataset: The prepared categorical training dataset.
"""
covcol = "dartsml_has_coverage" if self.target == "darts_mllabels" else "darts_has_coverage"
dataset = self.create(filter_target_col=covcol)
taskcol = self.taskcol(task) taskcol = self.taskcol(task)
valid_labels = dataset[taskcol].notna() valid_labels = dataset[taskcol].notna()
cols_to_drop = {"geometry", taskcol, covcol} cols_to_drop = {"geometry", taskcol, self.covcol}
cols_to_drop |= { cols_to_drop |= {
col col
for col in dataset.columns for col in dataset.columns
@ -505,3 +505,19 @@ class DatasetEnsemble:
z=model_labels, z=model_labels,
split=split, split=split,
) )
def create_cat_training_dataset(
self, task: Task, device: Literal["cpu", "cuda", "torch"]
) -> CategoricalTrainingDataset:
"""Create a categorical dataset for training.
Args:
task (Task): Task type.
device (Literal["cpu", "cuda", "torch"]): Device to load tensors onto.
Returns:
CategoricalTrainingDataset: The prepared categorical training dataset.
"""
dataset = self.create(filter_target_col=self.covcol)
return self._cat_and_split(dataset, task, device)

View file

@ -2,7 +2,6 @@
import pickle import pickle
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Literal
import cupy as cp import cupy as cp
import cyclopts import cyclopts
@ -23,6 +22,7 @@ from xgboost.sklearn import XGBClassifier
from entropice.ml.dataset import DatasetEnsemble from entropice.ml.dataset import DatasetEnsemble
from entropice.ml.inference import predict_proba from entropice.ml.inference import predict_proba
from entropice.utils.paths import get_cv_results_dir from entropice.utils.paths import get_cv_results_dir
from entropice.utils.types import Model, Task
traceback.install() traceback.install()
pretty.install() pretty.install()
@ -50,12 +50,20 @@ _metrics = {
@cyclopts.Parameter("*") @cyclopts.Parameter("*")
@dataclass @dataclass(frozen=True, kw_only=True)
class CVSettings: class CVSettings:
n_iter: int = 2000 n_iter: int = 2000
robust: bool = False robust: bool = False
task: Literal["binary", "count", "density"] = "binary" task: Task = "binary"
model: Literal["espa", "xgboost", "rf", "knn"] = "espa" model: Model = "espa"
@dataclass(frozen=True, kw_only=True)
class TrainingSettings(DatasetEnsemble, CVSettings):
param_grid: dict
cv_splits: int
metrics: list[str]
classes: list[str]
def _create_clf( def _create_clf(
@ -141,7 +149,7 @@ def random_cv(
"""Perform random cross-validation on the training dataset. """Perform random cross-validation on the training dataset.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Grid): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000. n_iter (int, optional): Number of parameter settings that are sampled. Defaults to 2000.
robust (bool, optional): Whether to use robust training. Defaults to False. robust (bool, optional): Whether to use robust training. Defaults to False.
@ -202,18 +210,18 @@ def random_cv(
# Store the search settings # Store the search settings
# First convert the param_grid distributions to a serializable format # First convert the param_grid distributions to a serializable format
param_grid_serializable = _serialize_param_grid(param_grid) param_grid_serializable = _serialize_param_grid(param_grid)
combined_settings = { combined_settings = TrainingSettings(
**asdict(settings), **asdict(settings),
**asdict(dataset_ensemble), **asdict(dataset_ensemble),
"param_grid": param_grid_serializable, param_grid=param_grid_serializable,
"cv_splits": cv.get_n_splits(), cv_splits=cv.get_n_splits(),
"metrics": metrics, metrics=metrics,
"classes": training_data.y.labels, classes=training_data.y.labels,
} )
settings_file = results_dir / "search_settings.toml" settings_file = results_dir / "search_settings.toml"
print(f"Storing search settings to {settings_file}") print(f"Storing search settings to {settings_file}")
with open(settings_file, "w") as f: with open(settings_file, "w") as f:
toml.dump({"settings": combined_settings}, f) toml.dump({"settings": asdict(combined_settings)}, f)
# Store the best estimator model # Store the best estimator model
best_model_file = results_dir / "best_estimator_model.pkl" best_model_file = results_dir / "best_estimator_model.pkl"

View file

@ -33,6 +33,7 @@ from stopuhr import stopwatch
from xdggs.healpix import HealpixInfo from xdggs.healpix import HealpixInfo
from entropice.spatial import grids from entropice.spatial import grids
from entropice.utils.types import Grid
@dataclass(frozen=True) @dataclass(frozen=True)
@ -585,7 +586,7 @@ def aggregate_raster_into_grid(
raster: xr.Dataset | Callable[[], xr.Dataset], raster: xr.Dataset | Callable[[], xr.Dataset],
grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame], grid_gdf: gpd.GeoDataFrame | list[gpd.GeoDataFrame],
aggregations: _Aggregations | Literal["interpolate"], aggregations: _Aggregations | Literal["interpolate"],
grid: Literal["hex", "healpix"], grid: Grid,
level: int, level: int,
n_partitions: int | None = 20, n_partitions: int | None = 20,
concurrent_partitions: int = 5, concurrent_partitions: int = 5,
@ -599,7 +600,7 @@ def aggregate_raster_into_grid(
If a list of GeoDataFrames is provided, each will be processed as a separate partition. If a list of GeoDataFrames is provided, each will be processed as a separate partition.
No further partitioning will be done and the n_partitions argument will be ignored. No further partitioning will be done and the n_partitions argument will be ignored.
aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform. aggregations (_Aggregations | Literal["interpolate"]): The aggregations to perform.
grid (Literal["hex", "healpix"]): The type of grid to use. grid (Grid): The type of grid to use.
level (int): The level of the grid. level (int): The level of the grid.
n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20. n_partitions (int | None, optional): Number of partitions to divide the grid into. Defaults to 20.
concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions. concurrent_partitions (int, optional): Maximum number of worker processes when processing partitions.

View file

@ -5,7 +5,6 @@ Date: 09. June 2025
""" """
from concurrent.futures import ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Literal
import cartopy.crs as ccrs import cartopy.crs as ccrs
import cartopy.feature as cfeature import cartopy.feature as cfeature
@ -26,16 +25,17 @@ from stopuhr import stopwatch
from xdggs.healpix import HealpixInfo from xdggs.healpix import HealpixInfo
from entropice.utils.paths import get_grid_file, get_grid_viz_file, watermask_file from entropice.utils.paths import get_grid_file, get_grid_viz_file, watermask_file
from entropice.utils.types import Grid
traceback.install() traceback.install()
pretty.install() pretty.install()
def open(grid: Literal["hex", "healpix"], level: int): def open(grid: Grid, level: int):
"""Open a saved grid from parquet file. """Open a saved grid from parquet file.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Grid): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
Returns: Returns:
@ -47,11 +47,11 @@ def open(grid: Literal["hex", "healpix"], level: int):
return grid return grid
def get_cell_ids(grid: Literal["hex", "healpix"], level: int): def get_cell_ids(grid: Grid, level: int):
"""Get the cell IDs of a saved grid. """Get the cell IDs of a saved grid.
Args: Args:
grid (Literal["hex", "healpix"]): The grid type to use. grid (Grid): The grid type to use.
level (int): The grid level to use. level (int): The grid level to use.
Returns: Returns:
@ -292,7 +292,7 @@ def vizualize_grid(data: gpd.GeoDataFrame, grid: str, level: int) -> plt.Figure:
return fig return fig
def cli(grid: Literal["hex", "healpix"], level: int): def cli(grid: Grid, level: int):
"""CLI entry point.""" """CLI entry point."""
print(f"Creating {grid} grid at level {level}...") print(f"Creating {grid} grid at level {level}...")
if grid == "hex": if grid == "hex":

View file

@ -6,6 +6,8 @@ import os
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
from entropice.utils.types import Grid, Task
DATA_DIR = ( DATA_DIR = (
Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice" Path(os.environ.get("FAST_DATA_DIR", None) or os.environ.get("DATA_DIR", None) or "data").resolve() / "entropice"
) )
@ -42,23 +44,23 @@ dartsl2_cov_file = RTS_DIR / "DARTS_NitzeEtAl_v1-2_coverage_2018-2023_level2.par
darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps" darts_ml_training_labels_repo = RTS_LABELS_DIR / "ML_training_labels" / "retrogressive_thaw_slumps"
def _get_gridname(grid: Literal["hex", "healpix"], level: int) -> str: def _get_gridname(grid: Grid, level: int) -> str:
return f"permafrost_{grid}{level}" return f"permafrost_{grid}{level}"
def get_grid_file(grid: Literal["hex", "healpix"], level: int) -> Path: def get_grid_file(grid: Grid, level: int) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
gridfile = GRIDS_DIR / f"{gridname}_grid.parquet" gridfile = GRIDS_DIR / f"{gridname}_grid.parquet"
return gridfile return gridfile
def get_grid_viz_file(grid: Literal["hex", "healpix"], level: int) -> Path: def get_grid_viz_file(grid: Grid, level: int) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
vizfile = FIGURES_DIR / f"{gridname}_grid.png" vizfile = FIGURES_DIR / f"{gridname}_grid.png"
return vizfile return vizfile
def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool = False) -> Path: def get_darts_rts_file(grid: Grid, level: int, labels: bool = False) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
if labels: if labels:
rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet" rtsfile = RTS_LABELS_DIR / f"{gridname}_darts-mllabels.parquet"
@ -67,13 +69,13 @@ def get_darts_rts_file(grid: Literal["hex", "healpix"], level: int, labels: bool
return rtsfile return rtsfile
def get_annual_embeddings_file(grid: Literal["hex", "healpix"], level: int, year: int) -> Path: def get_annual_embeddings_file(grid: Grid, level: int, year: int) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
embfile = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet" embfile = EMBEDDINGS_DIR / f"{gridname}_embeddings-{year}.parquet"
return embfile return embfile
def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path: def get_embeddings_store(grid: Grid, level: int) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr" embstore = EMBEDDINGS_DIR / f"{gridname}_embeddings.zarr"
return embstore return embstore
@ -81,9 +83,9 @@ def get_embeddings_store(grid: Literal["hex", "healpix"], level: int) -> Path:
def get_era5_stores( def get_era5_stores(
agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily", agg: Literal["daily", "monthly", "summer", "winter", "yearly", "seasonal", "shoulder"] = "daily",
grid: Literal["hex", "healpix"] | None = None, grid: Grid | None = None,
level: int | None = None, level: int | None = None,
): ) -> Path:
if grid is None or level is None: if grid is None or level is None:
(ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True) (ERA5_DIR / "intermediate").mkdir(parents=True, exist_ok=True)
return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr" return ERA5_DIR / "intermediate" / f"{agg}_climate.zarr"
@ -94,9 +96,9 @@ def get_era5_stores(
def get_arcticdem_stores( def get_arcticdem_stores(
grid: Literal["hex", "healpix"] | None = None, grid: Grid | None = None,
level: int | None = None, level: int | None = None,
): ) -> Path:
if grid is None or level is None: if grid is None or level is None:
return DATA_DIR / "arcticdem32m.icechunk.zarr" return DATA_DIR / "arcticdem32m.icechunk.zarr"
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
@ -104,7 +106,7 @@ def get_arcticdem_stores(
return aligned_path return aligned_path
def get_train_dataset_file(grid: Literal["hex", "healpix"], level: int) -> Path: def get_train_dataset_file(grid: Grid, level: int) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet" dataset_file = TRAINING_DIR / f"{gridname}_train_dataset.parquet"
return dataset_file return dataset_file
@ -122,9 +124,9 @@ def get_dataset_cache(eid: str, subset: str | None = None, batch: tuple[int, int
def get_cv_results_dir( def get_cv_results_dir(
name: str, name: str,
grid: Literal["hex", "healpix"], grid: Grid,
level: int, level: int,
task: Literal["binary", "count", "density"], task: Task,
) -> Path: ) -> Path:
gridname = _get_gridname(grid, level) gridname = _get_gridname(grid, level)
now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

View file

@ -0,0 +1,11 @@
"""Shared types used across the entropice codebase."""
from typing import Literal
type Grid = Literal["hex", "healpix"]
type GridLevel = Literal["hex3", "hex4", "hex5", "hex6", "healpix6", "healpix7", "healpix8", "healpix9", "healpix10"]
type TargetDataset = Literal["darts_rts", "darts_mllabels"]
type L0SourceDataset = Literal["ArcticDEM", "ERA5", "AlphaEarth"]
type L2SourceDataset = Literal["ArcticDEM", "ERA5-shoulder", "ERA5-seasonal", "ERA5-yearly", "AlphaEarth"]
type Task = Literal["binary", "count", "density"]
type Model = Literal["espa", "xgboost", "rf", "knn"]